Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2025-04-04 22:43:40 +08:00)
Add flag to optionally allow missing layer keys in FLUX lora loader.
This commit is contained in:
commit 50897ba066 (parent 3510643870)
@@ -64,7 +64,10 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
         assert len(src_layer_dict) == 0
 
     def add_qkv_lora_layer_if_present(
-        src_keys: list[str], src_weight_shapes: list[tuple[int, int]], dst_qkv_key: str
+        src_keys: list[str],
+        src_weight_shapes: list[tuple[int, int]],
+        dst_qkv_key: str,
+        allow_missing_keys: bool = False,
     ) -> None:
         """Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
         stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
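The docstring's distinction is easiest to see with shapes. Below is a minimal sketch of the fusion it describes (illustrative names and sizes only, not InvokeAI code): each diffusers-format layer carries its own up/down pair, and the resulting Q, K, V deltas are concatenated row-wise to match BFL's single qkv projection.

import torch

# Illustrative sizes; rank and hidden_size are assumptions, not from the diff.
hidden_size, rank = 8, 2

def lora_delta(out_features: int, in_features: int) -> torch.Tensor:
    up = torch.randn(out_features, rank)    # diffusers "lora_B.weight"
    down = torch.randn(rank, in_features)   # diffusers "lora_A.weight"
    return up @ down

# Separate Q, K, V deltas in the diffusers layout...
delta_q = lora_delta(hidden_size, hidden_size)
delta_k = lora_delta(hidden_size, hidden_size)
delta_v = lora_delta(hidden_size, hidden_size)

# ...concatenated along the output dimension to match the BFL qkv projection.
delta_qkv = torch.cat([delta_q, delta_k, delta_v], dim=0)
assert delta_qkv.shape == (3 * hidden_size, hidden_size)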
@@ -74,9 +77,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
         if not any(keys_present):
             return
 
-        src_layer_dicts = [grouped_state_dict.pop(key, None) for key in src_keys]
         sub_layers: list[LoRALayer] = []
-        for src_layer_dict, src_weight_shape in zip(src_layer_dicts, src_weight_shapes, strict=True):
+        for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
+            src_layer_dict = grouped_state_dict.pop(src_key, None)
             if src_layer_dict is not None:
                 values = {
                     "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
@@ -89,6 +92,8 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
                 sub_layers.append(LoRALayer.from_state_dict_values(values=values))
                 assert len(src_layer_dict) == 0
             else:
+                if not allow_missing_keys:
+                    raise ValueError(f"Missing LoRA layer: '{src_key}'.")
                 values = {
                     "lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
                     "lora_down.weight": torch.zeros((1, src_weight_shape[1])),
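Why a zero placeholder is safe to substitute: a LoRA layer's contribution is up @ down, so rank-1 all-zero factors add exactly nothing to the base weights while preserving the shapes the qkv concatenation expects. A quick check of that claim (shapes follow the placeholder in the hunk above; the concrete sizes are illustrative):

import torch

out_features, in_features = 8, 8  # illustrative stand-ins for src_weight_shape
lora_up = torch.zeros((out_features, 1))
lora_down = torch.zeros((1, in_features))

delta = lora_up @ lora_down  # full (out_features, in_features) matrix of zeros
assert delta.shape == (out_features, in_features)
assert torch.count_nonzero(delta) == 0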
@@ -193,6 +198,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
                 (mlp_hidden_dim, hidden_size),
             ],
             f"single_blocks.{i}.linear1",
+            allow_missing_keys=True,
         )
 
         # Output projections.
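Taken together, the commit's control flow reduces to pop-or-raise-or-zero-fill. A self-contained sketch of that logic (a hypothetical helper with assumed names, not InvokeAI's API; the lora_A/lora_B to lora_down/lora_up translation follows the hunks above):

from typing import Dict
import torch

def pop_layer_values(
    grouped_state_dict: Dict[str, Dict[str, torch.Tensor]],
    src_key: str,
    src_weight_shape: tuple[int, int],
    allow_missing_keys: bool = False,
) -> Dict[str, torch.Tensor]:
    src_layer_dict = grouped_state_dict.pop(src_key, None)
    if src_layer_dict is not None:
        # Present: translate diffusers key names to the internal ones.
        return {
            "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
            "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
        }
    if not allow_missing_keys:
        raise ValueError(f"Missing LoRA layer: '{src_key}'.")
    # Tolerated absence: substitute a rank-1 zero placeholder.
    return {
        "lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
        "lora_down.weight": torch.zeros((1, src_weight_shape[1])),
    }

# Example: only Q is present; K is missing from the state dict.
sd = {"attn.to_q": {"lora_A.weight": torch.zeros(2, 8), "lora_B.weight": torch.zeros(8, 2)}}
pop_layer_values(sd, "attn.to_q", (8, 8))                           # found, translated
pop_layer_values(sd, "attn.to_k", (8, 8), allow_missing_keys=True)  # zero-filled
# pop_layer_values(sd, "attn.to_k", (8, 8))  # would raise ValueError

In the hunks shown, the flag is enabled only for the single-block linear1 call; the default elsewhere remains strict, so other missing keys still raise.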