Make T5 encoder optional in SD3 workflows.

This commit is contained in:
Ryan Dick 2024-10-25 18:49:28 +00:00 committed by Brandon
parent f1de11d6bf
commit 1eca4f12c8
2 changed files with 29 additions and 9 deletions

View File

@ -17,6 +17,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@ -61,7 +62,12 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype, device: torch.device
self,
context: InvocationContext,
conditioning_name: str,
joint_attention_dim: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
@ -72,8 +78,11 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
t5_embeds = sd3_conditioning.t5_embeds
if t5_embeds is None:
# TODO(ryand): Construct a zero tensor of the correct shape to use as the T5 conditioning.
raise NotImplementedError("SD3 inference without T5 conditioning is not yet supported.")
t5_embeds = torch.zeros(
(1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
device=device,
dtype=dtype,
)
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
clip_prompt_embeds = torch.nn.functional.pad(
@ -138,14 +147,24 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer)
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
do_classifier_free_guidance = True
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype, device
context=context,
conditioning_name=self.positive_text_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype, device
context=context,
conditioning_name=self.negative_text_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
# TODO(ryand): Support both sequential and batched CFG inference.
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
@ -160,8 +179,6 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
transformer_info = context.models.load(self.transformer.transformer)
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
assert isinstance(num_channels_latents, int)

View File

@ -22,6 +22,9 @@ from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
# The SD3 T5 max sequence length, set based on the default value in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@invocation(
"sd3_text_encoder",
@ -48,6 +51,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
t5_encoder: T5EncoderField | None = InputField(
title="T5Encoder",
default=None,
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
@ -61,10 +65,9 @@ class Sd3TextEncoderInvocation(BaseInvocation):
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
t5_max_seq_len = 256
t5_embeddings: torch.Tensor | None = None
if self.t5_encoder is not None:
t5_embeddings = self._t5_encode(context, t5_max_seq_len)
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)
conditioning_data = ConditioningFieldData(
conditionings=[