Dynamically select smaller t5 seq len to save inference time.

Ryan Dick 2024-11-29 00:15:32 +00:00
parent a76a1244bc
commit 437d1087a2
2 changed files with 31 additions and 11 deletions
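
The core idea of the change: instead of always padding T5 prompts to the node's maximum sequence length, the encoder now picks the smallest allowed length that still fits the tokenized prompt. A minimal standalone sketch of that selection rule (the helper name pick_seq_len and the example token counts are illustrative, not part of the diff):

def pick_seq_len(prompt_token_count: int, valid_seq_lens: list[int]) -> int:
    # Choose the shortest allowed sequence length that can hold the prompt;
    # fall back to the longest option for prompts that exceed every candidate.
    valid_seq_lens = sorted(valid_seq_lens)
    for seq_len in valid_seq_lens:
        if seq_len >= prompt_token_count:
            return seq_len
    return valid_seq_lens[-1]

assert pick_seq_len(40, [128, 256, 512]) == 128   # short prompt -> short T5 sequence
assert pick_seq_len(300, [128, 256, 512]) == 512  # long prompt -> full length
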


@@ -22,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -41,6 +41,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
     t5_max_seq_len: Literal[256, 512] = InputField(
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
+    use_short_t5_seq_len: bool = InputField(
+        description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
+        "performance and reduce peak memory, but may result in slightly different image outputs.",
+        default=True,
+    )
     prompt: str = InputField(
         description="Text prompt to encode.",
         ui_component=UIComponent.Textarea,
@@ -65,6 +70,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
         prompt = [self.prompt]
+        valid_seq_lens = [self.t5_max_seq_len]
+        if self.use_short_t5_seq_len:
+            valid_seq_lens = [128, 256, 512]
         with (
             t5_text_encoder_info as t5_text_encoder,
             t5_tokenizer_info as t5_tokenizer,
@@ -72,10 +81,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, T5Tokenizer)
-            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
+            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)
             context.util.signal_progress("Running T5 encoder")
-            prompt_embeds = t5_encoder(prompt)
+            prompt_embeds = t5_encoder(prompt, valid_seq_lens)
         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -113,10 +122,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
                 # There are currently no supported CLIP quantized models. Add support here if needed.
                 raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
-            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)
             context.util.signal_progress("Running CLIP encoder")
-            pooled_prompt_embeds = clip_encoder(prompt)
+            pooled_prompt_embeds = clip_encoder(prompt, [77])
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return pooled_prompt_embeds
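
To summarize the call-site changes above: the T5 encoder now receives a list of candidate lengths derived from use_short_t5_seq_len and t5_max_seq_len, while the CLIP encoder keeps its single fixed length of 77. A rough sketch of that relationship, using the field names from the diff (the helper below is illustrative, not part of the commit):

def t5_candidate_seq_lens(t5_max_seq_len: int, use_short_t5_seq_len: bool) -> list[int]:
    # Mirrors the logic added to the invocation: either the single configured
    # length, or a set of shorter options the encoder may choose from.
    if use_short_t5_seq_len:
        return [128, 256, 512]
    return [t5_max_seq_len]

# The CLIP call site always passes [77], since CLIP's text context is fixed.
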


@@ -1,32 +1,43 @@
 # Initially pulled from https://github.com/black-forest-labs/flux
 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer


 class HFEncoder(nn.Module):
-    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
+    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
         super().__init__()
-        self.max_length = max_length
         self.is_clip = is_clip
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
         self.tokenizer = tokenizer
         self.hf_module = encoder
         self.hf_module = self.hf_module.eval().requires_grad_(False)

-    def forward(self, text: list[str]) -> Tensor:
+    def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
+        valid_seq_lens = sorted(valid_seq_lens)
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
-            max_length=self.max_length,
-            return_length=False,
+            max_length=max(valid_seq_lens),
+            return_length=True,
             return_overflowing_tokens=False,
             padding="max_length",
             return_tensors="pt",
         )
+        seq_len: int = batch_encoding["length"][0].item()
+        # Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
+        selected_seq_len = valid_seq_lens[-1]
+        for len in valid_seq_lens:
+            if len >= seq_len:
+                selected_seq_len = len
+                break
+        input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=input_ids.to(self.hf_module.device),
             attention_mask=None,
             output_hidden_states=False,
         )
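
A hypothetical usage sketch of the updated HFEncoder (the checkpoint name and direct Hub loading are illustrative assumptions; in InvokeAI the model and tokenizer are loaded through the model manager):

import torch
from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
model = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl", torch_dtype=torch.bfloat16)
encoder = HFEncoder(model, tokenizer, is_clip=False)

# A short prompt is padded only to 128 tokens instead of 512, so the T5
# forward pass runs over a 4x shorter sequence.
with torch.no_grad():
    embeds = encoder(["a photo of a cat"], valid_seq_lens=[128, 256, 512])
print(embeds.shape)  # expected: (1, 128, hidden_dim)

As the new field description notes, encoding over a shorter padded sequence can change the conditioning slightly, so images may differ from runs that always pad to 512 tokens.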