Add flag to optionally allow missing layer keys in FLUX lora loader.

Author: Ryan Dick
Date: 2024-11-11 19:36:40 +00:00
parent 3510643870
commit 50897ba066

@@ -64,7 +64,10 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
         assert len(src_layer_dict) == 0
 
     def add_qkv_lora_layer_if_present(
-        src_keys: list[str], src_weight_shapes: list[tuple[int, int]], dst_qkv_key: str
+        src_keys: list[str],
+        src_weight_shapes: list[tuple[int, int]],
+        dst_qkv_key: str,
+        allow_missing_keys: bool = False,
     ) -> None:
         """Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
         stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
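For context on the docstring above: a LoRA delta is the product of an up matrix (out_features x rank) and a down matrix (rank x in_features). Because the fused BFL-format QKV weight stacks Q, K, and V along the output dimension, the per-matrix deltas from a diffusers-format LoRA can be concatenated on that same axis. A rough, self-contained sketch of the shape arithmetic (illustrative only, with made-up sizes; not the loader's actual code):

import torch

hidden, rank = 8, 2  # hypothetical sizes for illustration

def lora_delta(out_features: int, in_features: int) -> torch.Tensor:
    # A LoRA weight delta: up ("lora_B"/"lora_up") @ down ("lora_A"/"lora_down").
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    return up @ down

delta_q = lora_delta(hidden, hidden)
delta_k = lora_delta(hidden, hidden)
delta_v = lora_delta(hidden, hidden)

# Stacking the separate Q/K/V deltas along dim 0 matches the layout of a
# fused QKV weight of shape (3 * hidden, hidden).
delta_qkv = torch.cat([delta_q, delta_k, delta_v], dim=0)
assert delta_qkv.shape == (3 * hidden, hidden)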
@@ -74,9 +77,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
         if not any(keys_present):
             return
 
-        src_layer_dicts = [grouped_state_dict.pop(key, None) for key in src_keys]
         sub_layers: list[LoRALayer] = []
-        for src_layer_dict, src_weight_shape in zip(src_layer_dicts, src_weight_shapes, strict=True):
+        for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
+            src_layer_dict = grouped_state_dict.pop(src_key, None)
             if src_layer_dict is not None:
                 values = {
                     "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
@@ -89,6 +92,8 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
                 sub_layers.append(LoRALayer.from_state_dict_values(values=values))
                 assert len(src_layer_dict) == 0
             else:
+                if not allow_missing_keys:
+                    raise ValueError(f"Missing LoRA layer: '{src_key}'.")
                 values = {
                     "lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
                     "lora_down.weight": torch.zeros((1, src_weight_shape[1])),
@@ -193,6 +198,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
                 (mlp_hidden_dim, hidden_size),
             ],
             f"single_blocks.{i}.linear1",
+            allow_missing_keys=True,
         )
 
         # Output projections.
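This single-block linear1 call site opts into allow_missing_keys=True, evidently because some diffusers-format FLUX LoRAs omit one or more of the sub-layers that feed this fused weight, while the other call sites keep the strict default and still raise on a missing key. A simplified, hypothetical illustration of the two behaviors (not the loader's real API; the key name is just an example):

def pop_sub_layer(grouped_state_dict: dict, src_key: str, allow_missing_keys: bool = False):
    # Strict by default: a missing key is an error. When permissive, return
    # None so the caller can substitute a zero placeholder instead.
    src_layer_dict = grouped_state_dict.pop(src_key, None)
    if src_layer_dict is None and not allow_missing_keys:
        raise ValueError(f"Missing LoRA layer: '{src_key}'.")
    return src_layer_dict

# pop_sub_layer({}, "single_transformer_blocks.0.proj_mlp")  # raises ValueError
assert pop_sub_layer({}, "single_transformer_blocks.0.proj_mlp", allow_missing_keys=True) is None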