mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 12:37:45 +08:00
Make T5 encoder optional in SD3 workflows.
This commit is contained in:
parent
f1de11d6bf
commit
1eca4f12c8
@ -17,6 +17,7 @@ from invokeai.app.invocations.fields import (
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
@ -61,7 +62,12 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _load_text_conditioning(
|
||||
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype, device: torch.device
|
||||
self,
|
||||
context: InvocationContext,
|
||||
conditioning_name: str,
|
||||
joint_attention_dim: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(conditioning_name)
|
||||
@ -72,8 +78,11 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
t5_embeds = sd3_conditioning.t5_embeds
|
||||
if t5_embeds is None:
|
||||
# TODO(ryand): Construct a zero tensor of the correct shape to use as the T5 conditioning.
|
||||
raise NotImplementedError("SD3 inference without T5 conditioning is not yet supported.")
|
||||
t5_embeds = torch.zeros(
|
||||
(1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
|
||||
clip_prompt_embeds = torch.nn.functional.pad(
|
||||
@ -138,14 +147,24 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Load/process the conditioning data.
|
||||
# TODO(ryand): Make CFG optional.
|
||||
do_classifier_free_guidance = True
|
||||
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
|
||||
context, self.positive_text_conditioning.conditioning_name, inference_dtype, device
|
||||
context=context,
|
||||
conditioning_name=self.positive_text_conditioning.conditioning_name,
|
||||
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
|
||||
context, self.negative_text_conditioning.conditioning_name, inference_dtype, device
|
||||
context=context,
|
||||
conditioning_name=self.negative_text_conditioning.conditioning_name,
|
||||
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
# TODO(ryand): Support both sequential and batched CFG inference.
|
||||
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
|
||||
@ -160,8 +179,6 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Prepare the CFG scale list.
|
||||
cfg_scale = self._prepare_cfg_scale(len(timesteps))
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Generate initial latent noise.
|
||||
num_channels_latents = transformer_info.model.config.in_channels
|
||||
assert isinstance(num_channels_latents, int)
|
||||
|
@ -22,6 +22,9 @@ from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
|
||||
|
||||
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
|
||||
SD3_T5_MAX_SEQ_LEN = 256
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_text_encoder",
|
||||
@ -48,6 +51,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
title="T5Encoder",
|
||||
default=None,
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@ -61,10 +65,9 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
|
||||
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
|
||||
|
||||
t5_max_seq_len = 256
|
||||
t5_embeddings: torch.Tensor | None = None
|
||||
if self.t5_encoder is not None:
|
||||
t5_embeddings = self._t5_encode(context, t5_max_seq_len)
|
||||
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
|
Loading…
Reference in New Issue
Block a user