Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2025-01-07 03:17:05 +08:00)
Dynamically select smaller t5 seq len to save inference time.
commit 437d1087a2
parent a76a1244bc
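In short, the commit replaces HFEncoder's fixed max_length with a list of candidate sequence lengths and, after tokenizing, pads only to the smallest candidate that can hold the prompt. A minimal standalone sketch of that selection rule follows; the helper name pick_seq_len and the example token counts are illustrative, not part of the commit:

# Sketch of the selection rule introduced by this commit. pick_seq_len is a
# hypothetical helper; the commit implements the same logic inline in
# HFEncoder.forward().
def pick_seq_len(num_tokens: int, valid_seq_lens: list[int]) -> int:
    valid_seq_lens = sorted(valid_seq_lens)
    for candidate in valid_seq_lens:
        if candidate >= num_tokens:
            return candidate
    # No candidate fits; fall back to the largest (tokens are truncated to it anyway).
    return valid_seq_lens[-1]

assert pick_seq_len(20, [128, 256, 512]) == 128   # short prompt -> shortest bucket
assert pick_seq_len(300, [128, 256, 512]) == 512  # long prompt -> largest bucket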
@@ -22,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -41,6 +41,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
     t5_max_seq_len: Literal[256, 512] = InputField(
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
+    use_short_t5_seq_len: bool = InputField(
+        description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
+        + "performance and reduce peak memory, but may result in slightly different image outputs.",
+        default=True,
+    )
     prompt: str = InputField(
         description="Text prompt to encode.",
         ui_component=UIComponent.Textarea,
@@ -65,6 +70,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
 
         prompt = [self.prompt]
 
+        valid_seq_lens = [self.t5_max_seq_len]
+        if self.use_short_t5_seq_len:
+            valid_seq_lens = [128, 256, 512]
+
         with (
             t5_text_encoder_info as t5_text_encoder,
             t5_tokenizer_info as t5_tokenizer,
@@ -72,10 +81,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, T5Tokenizer)
 
-            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
+            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)
 
             context.util.signal_progress("Running T5 encoder")
-            prompt_embeds = t5_encoder(prompt)
+            prompt_embeds = t5_encoder(prompt, valid_seq_lens)
 
             assert isinstance(prompt_embeds, torch.Tensor)
             return prompt_embeds
@@ -113,10 +122,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
                 # There are currently no supported CLIP quantized models. Add support here if needed.
                 raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
 
-            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)
 
             context.util.signal_progress("Running CLIP encoder")
-            pooled_prompt_embeds = clip_encoder(prompt)
+            pooled_prompt_embeds = clip_encoder(prompt, [77])
 
             assert isinstance(pooled_prompt_embeds, torch.Tensor)
             return pooled_prompt_embeds
@@ -1,32 +1,43 @@
 # Initially pulled from https://github.com/black-forest-labs/flux
 
 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer
 
 
 class HFEncoder(nn.Module):
-    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
+    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
         super().__init__()
-        self.max_length = max_length
         self.is_clip = is_clip
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
         self.tokenizer = tokenizer
         self.hf_module = encoder
         self.hf_module = self.hf_module.eval().requires_grad_(False)
 
-    def forward(self, text: list[str]) -> Tensor:
+    def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
+        valid_seq_lens = sorted(valid_seq_lens)
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
-            max_length=self.max_length,
-            return_length=False,
+            max_length=max(valid_seq_lens),
+            return_length=True,
             return_overflowing_tokens=False,
             padding="max_length",
             return_tensors="pt",
         )
 
+        seq_len: int = batch_encoding["length"][0].item()
+        # Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
+        selected_seq_len = valid_seq_lens[-1]
+        for len in valid_seq_lens:
+            if len >= seq_len:
+                selected_seq_len = len
+                break
+
+        input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
+
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=input_ids.to(self.hf_module.device),
             attention_mask=None,
             output_hidden_states=False,
         )
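For context, here is a sketch of how the patched HFEncoder could be exercised on its own. The checkpoint name, import path, and printed shape are assumptions for illustration; inside InvokeAI these models are loaded through the model manager, as in the invocation diff above.

# Usage sketch (not part of the commit).
from transformers import T5EncoderModel, T5Tokenizer

from invokeai.backend.flux.modules.conditioner import HFEncoder  # assumed import path

tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")  # assumed checkpoint
encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

t5_encoder = HFEncoder(encoder, tokenizer, is_clip=False)

# A short prompt now pads only to 128 tokens instead of the full 512,
# so the T5 forward pass runs over a much shorter sequence.
prompt_embeds = t5_encoder(["a photo of a cat"], valid_seq_lens=[128, 256, 512])
print(prompt_embeds.shape)  # e.g. torch.Size([1, 128, 4096]) for T5-XXL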