Improve docs related to dynamic T5 sequence length selection.

Ryan Dick 2024-11-29 16:11:51 +00:00
parent 4581a37a48
commit 8d04ec3f95
2 changed files with 14 additions and 2 deletions


@@ -81,7 +81,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
valid_seq_lens = [self.t5_max_seq_len]
if self.use_short_t5_seq_len:
    valid_seq_lens = [128, 256, 512]
    # We allow a minimum sequence length of 128. Going too short results in more significant image changes.
    valid_seq_lens = list(range(128, self.t5_max_seq_len, 128))
    valid_seq_lens.append(self.t5_max_seq_len)
with (
    t5_text_encoder_info as t5_text_encoder,
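
For reference, the new construction (replacing the hard-coded [128, 256, 512] list) makes the candidate lengths every multiple of 128 up to the configured maximum, with the maximum itself always included. A minimal sketch of the values this produces, assuming t5_max_seq_len = 512 (the usual FLUX dev setting; schnell typically uses 256):

# Sketch only; t5_max_seq_len = 512 is an assumed example value, not part of the diff.
t5_max_seq_len = 512
valid_seq_lens = list(range(128, t5_max_seq_len, 128))  # [128, 256, 384]
valid_seq_lens.append(t5_max_seq_len)                   # [128, 256, 384, 512]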


@@ -15,7 +15,17 @@ class HFEncoder(nn.Module):
    self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
"""Encode text into a tensor.
Args:
text: A list of text prompts to encode.
valid_seq_lens: A list of valid sequence lengths. The shortest valid sequence length that can contain the
text will be used. If the largest valid sequence length cannot contain the text, the encoding will be
truncated.
"""
valid_seq_lens = sorted(valid_seq_lens)
# Perform initial encoding with the maximum valid sequence length.
batch_encoding = self.tokenizer(
    text,
    truncation=True,
@@ -26,8 +36,8 @@ class HFEncoder(nn.Module):
return_tensors="pt",
)
# Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
seq_len: int = batch_encoding["length"][0].item()
selected_seq_len = valid_seq_lens[-1]
for len in valid_seq_lens:
    if len >= seq_len:
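
The hunk is cut off at the comparison, but the behaviour described in the new docstring is: pick the smallest valid sequence length that can hold the tokenized prompt, and fall back to the largest one (truncating) if none fits. A minimal standalone sketch of that selection rule, with the loop body after the cut-off assumed from the docstring (the diff uses len as the loop variable; candidate is used here only to avoid shadowing the builtin):

# Standalone sketch of the selection rule; not the actual HFEncoder code.
def select_seq_len(seq_len: int, valid_seq_lens: list[int]) -> int:
    valid_seq_lens = sorted(valid_seq_lens)
    # Default to the largest valid length; if even that cannot hold the prompt,
    # the encoding is truncated to it.
    selected_seq_len = valid_seq_lens[-1]
    for candidate in valid_seq_lens:
        if candidate >= seq_len:
            selected_seq_len = candidate
            break
    return selected_seq_len

assert select_seq_len(90, [128, 256, 512]) == 128   # shortest length that fits
assert select_seq_len(300, [128, 256, 512]) == 512  # 128 and 256 are too small
assert select_seq_len(600, [128, 256, 512]) == 512  # longer than the max -> truncated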