Update Transformers to 4.34 and fix pad_to_multiple_of

This commit is contained in:
Wubbbi 2023-10-06 18:55:59 +02:00 committed by Kent Keirsey
parent 8702a63197
commit a0be83e370
2 changed files with 11 additions and 4 deletions

View File

@ -166,6 +166,13 @@ class ModelPatcher:
init_tokens_count = None
new_tokens_added = None
# pad_to_multiple_of is required since Transformers 4.32, see transformers/pull/25088
# More information: https://tinyurl.com/ycxxzdhh
if "A100" in torch.cuda.get_device_name():
pad_to_multiple_of = 64
else:
pad_to_multiple_of = 8
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
@ -175,7 +182,7 @@ class ModelPatcher:
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
def _get_trigger(ti_name, index):
trigger = ti_name
@ -190,7 +197,7 @@ class ModelPatcher:
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
@ -222,7 +229,7 @@ class ModelPatcher:
finally:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count)
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
@classmethod
@contextmanager

View File

@ -82,7 +82,7 @@ dependencies = [
"torchvision~=0.16",
"torchmetrics~=0.11.0",
"torchsde~=0.2.5",
"transformers~=4.31.0",
"transformers~=4.34.0",
"uvicorn[standard]~=0.21.1",
"windows-curses; sys_platform=='win32'",
]