# Mirror of https://github.com/invoke-ai/InvokeAI.git
# Synced 2025-04-03 07:21:32 +08:00
# 140 lines, 6.0 KiB, Python
from contextlib import ExitStack
|
|
from typing import Iterator, Literal, Optional, Tuple
|
|
|
|
import torch
|
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
|
from invokeai.app.invocations.fields import (
|
|
FieldDescriptions,
|
|
FluxConditioningField,
|
|
Input,
|
|
InputField,
|
|
TensorField,
|
|
UIComponent,
|
|
)
|
|
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
|
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
|
from invokeai.backend.model_manager.config import ModelFormat
|
|
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
|
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
|
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
|
|
|
|
|
@invocation(
    "flux_text_encoder",
    title="FLUX Text Encoding",
    tags=["prompt", "conditioning", "flux"],
    category="conditioning",
    version="1.1.1",
    classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for a flux image."""

    # Node overview: the prompt is run through the T5 encoder and the CLIP encoder, and the two
    # embeddings are saved together as one FLUXConditioningInfo entry (optionally with a region mask).
    # NOTE: the class docstring above is kept verbatim because the invocation framework may surface it.

    # CLIP tokenizer + text encoder (and any CLIP LoRAs); must be connected from another node.
    clip: CLIPField = InputField(title="CLIP", description=FieldDescriptions.clip, input=Input.Connection)
    # T5 tokenizer + text encoder; must be connected from another node.
    t5_encoder: T5EncoderField = InputField(
        title="T5Encoder", description=FieldDescriptions.t5_encoder, input=Input.Connection
    )
    # Passed to HFEncoder as the T5 maximum sequence length.
    t5_max_seq_len: Literal[256, 512] = InputField(
        description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
    )
    # The text that both encoders receive.
    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
    # Optional mask; it is not used during encoding, only forwarded on the output conditioning field.
    mask: Optional[TensorField] = InputField(
        default=None, description="A mask defining the region that this conditioning prompt applies to."
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
        """Encode the prompt with T5 and CLIP, persist the result, and return a reference to it."""
        # Each encoder lives in its own helper so every model reference is scoped to that helper's frame;
        # that lets the T5 model be freed and garbage collected before CLIP is loaded, should memory require it.
        t5_embeds = self._t5_encode(context)
        clip_embeds = self._clip_encode(context)

        flux_info = FLUXConditioningInfo(clip_embeds=clip_embeds, t5_embeds=t5_embeds)
        saved_name = context.conditioning.save(ConditioningFieldData(conditionings=[flux_info]))

        output_field = FluxConditioningField(conditioning_name=saved_name, mask=self.mask)
        return FluxConditioningOutput(conditioning=output_field)

    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
        """Run ``self.prompt`` through the T5 encoder and return the resulting tensor.

        Raises:
            AssertionError: if the loaded models are not a T5EncoderModel / T5Tokenizer, or if the
                encoder output is not a tensor.
        """
        tokenizer_handle = context.models.load(self.t5_encoder.tokenizer)
        encoder_handle = context.models.load(self.t5_encoder.text_encoder)

        # Enter the text encoder first and the tokenizer second; they are released in reverse order.
        with encoder_handle as text_encoder, tokenizer_handle as tokenizer:
            assert isinstance(text_encoder, T5EncoderModel)
            assert isinstance(tokenizer, T5Tokenizer)

            # NOTE(review): the boolean argument presumably selects CLIP-style (pooled) output; False here
            # for T5. Confirm against HFEncoder's signature.
            encoder = HFEncoder(text_encoder, tokenizer, False, self.t5_max_seq_len)

            context.util.signal_progress("Running T5 encoder")
            embeds = encoder([self.prompt])

        assert isinstance(embeds, torch.Tensor)
        return embeds

    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
        """Run ``self.prompt`` through the (optionally LoRA-patched) CLIP encoder and return the tensor.

        Raises:
            ValueError: if the CLIP text encoder is not in the Diffusers format (no other format,
                e.g. a quantized one, is supported for LoRA patching).
            AssertionError: if the loaded models have unexpected types, the config is missing, or the
                encoder output is not a tensor.
        """
        tokenizer_handle = context.models.load(self.clip.tokenizer)
        encoder_handle = context.models.load(self.clip.text_encoder)

        # A single ExitStack owns every context: model-on-device, then tokenizer, then LoRA patches.
        # Unwinding is LIFO, so patches are removed before the tokenizer and model are released.
        with ExitStack() as stack:
            cached_weights, text_encoder = stack.enter_context(encoder_handle.model_on_device())
            tokenizer = stack.enter_context(tokenizer_handle)
            assert isinstance(text_encoder, CLIPTextModel)
            assert isinstance(tokenizer, CLIPTokenizer)

            encoder_config = encoder_handle.config
            assert encoder_config is not None

            # Only non-quantized (Diffusers) CLIP weights can have LoRA deltas written straight into them.
            # Add support for other formats here if they are ever needed.
            if encoder_config.format not in [ModelFormat.Diffusers]:
                raise ValueError(f"Unsupported model format: {encoder_config.format}")

            # Patching happens only after model_on_device() has been entered, i.e. once the encoder is on
            # its target device, which makes applying the patches faster.
            stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=text_encoder,
                    patches=self._clip_lora_iterator(context),
                    prefix=FLUX_LORA_CLIP_PREFIX,
                    dtype=text_encoder.dtype,
                    cached_weights=cached_weights,
                )
            )

            # NOTE(review): True presumably requests CLIP-style pooled output; 77 is passed as the max
            # length. Confirm against HFEncoder's signature.
            encoder = HFEncoder(text_encoder, tokenizer, True, 77)

            context.util.signal_progress("Running CLIP encoder")
            pooled = encoder([self.prompt])

        assert isinstance(pooled, torch.Tensor)
        return pooled

    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Lazily load each CLIP LoRA and yield ``(patch, weight)`` pairs, one LoRA at a time."""
        for lora_field in self.clip.loras:
            loaded_lora = context.models.load(lora_field.lora)
            assert isinstance(loaded_lora.model, ModelPatchRaw)
            yield (loaded_lora.model, lora_field.weight)
            # Drop this frame's reference to the loaded LoRA before the next one is loaded.
            del loaded_lora