mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
Improve docs related to dynamic T5 sequence length selection.
This commit is contained in:
parent
4581a37a48
commit
8d04ec3f95
@ -81,7 +81,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
valid_seq_lens = [self.t5_max_seq_len]
|
||||
if self.use_short_t5_seq_len:
|
||||
valid_seq_lens = [128, 256, 512]
|
||||
# We allow a minimum sequence length of 128. Going too short results in more significant image changes.
|
||||
valid_seq_lens = list(range(128, self.t5_max_seq_len, 128))
|
||||
valid_seq_lens.append(self.t5_max_seq_len)
|
||||
|
||||
with (
|
||||
t5_text_encoder_info as t5_text_encoder,
|
||||
|
@ -15,7 +15,17 @@ class HFEncoder(nn.Module):
|
||||
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
||||
|
||||
def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
|
||||
"""Encode text into a tensor.
|
||||
|
||||
Args:
|
||||
text: A list of text prompts to encode.
|
||||
valid_seq_lens: A list of valid sequence lengths. The shortest valid sequence length that can contain the
|
||||
text will be used. If the largest valid sequence length cannot contain the text, the encoding will be
|
||||
truncated.
|
||||
"""
|
||||
valid_seq_lens = sorted(valid_seq_lens)
|
||||
|
||||
# Perform initial encoding with the maximum valid sequence length.
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
@ -26,8 +36,8 @@ class HFEncoder(nn.Module):
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
seq_len: int = batch_encoding["length"][0].item()
|
||||
# Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
|
||||
seq_len: int = batch_encoding["length"][0].item()
|
||||
selected_seq_len = valid_seq_lens[-1]
|
||||
for len in valid_seq_lens:
|
||||
if len >= seq_len:
|
||||
|
Loading…
Reference in New Issue
Block a user