Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2025-01-09 04:18:46 +08:00)
Lots of updates centered around using the lora patcher rather than changing the modules in the transformer model
parent 5a035dd19f
commit f53da60b84
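The change of direction can be sketched as follows: rather than swapping `nn.Linear` modules in the transformer for LoRA-wrapped subclasses, the patcher adds the low-rank update to the existing module's weights and restores them afterwards. A minimal illustration only, with made-up shapes and plain PyTorch; this is not the repository's LoRAPatcher API:

import torch
from torch import nn

# Patch a LoRA update into an existing Linear in place, instead of replacing
# the module with a LoRA-wrapped subclass.
linear = nn.Linear(8, 8)
rank, scale = 4, 1.0
lora_A = torch.randn(rank, 8)  # hypothetical lora_A.weight
lora_B = torch.randn(8, rank)  # hypothetical lora_B.weight

with torch.no_grad():
    original_weight = linear.weight.detach().clone()  # saved so the patch can be undone
    linear.weight += scale * (lora_B @ lora_A)

# ... run generation with the patched module ...

with torch.no_grad():
    linear.weight.copy_(original_weight)  # restore after generation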
@@ -9,7 +9,6 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.modules.lora import replace_linear_with_lora

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (

@@ -305,8 +304,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        assert isinstance(transformer, Flux)
        config = transformer_info.config
        assert config is not None

        if self.transformer.structural_lora:
            replace_linear_with_lora(transformer, 128)

        # Apply LoRA models to the transformer.
        # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
@@ -20,7 +20,6 @@ from invokeai.backend.flux.modules.layers import (
    SingleStreamBlock,
    timestep_embedding,
)
from invokeai.backend.flux.modules.lora import replace_linear_with_lora


@dataclass
@@ -1,148 +0,0 @@
import torch

import bitsandbytes as bnb
from torch import nn

def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    for name, child in module.named_children():
        if isinstance(child, (LinearLora, BNBNF4LinearLora)):
            # TODO: We really need to undo this after generation runs
            return
        if isinstance(child, nn.Linear):
            dtype = child.weight.dtype
            loraModule = LinearLora
            if hasattr(child, "compute_dtype"):
                dtype = getattr(child, "compute_dtype")
                loraModule = BNBNF4LinearLora
            new_lora = loraModule(
                # Double the in features to accommodate the control image. This conditional is to avoid increasing the final_layer in features
                in_features=child.out_features*2 if child.in_features == child.out_features else child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,
                rank=max_rank,
                scale=scale,
                dtype=dtype,
                device=child.weight.device,
            )
            new_lora.weight = child.weight
            new_lora.bias = child.bias if child.bias is not None else None
            setattr(module, name, new_lora)
        else:
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )

class BNBNF4LinearLora(bnb.nn.LinearNF4):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            input_features=in_features,
            output_features=out_features,
            bias=bias is not None,
            device=device,
            compute_dtype=dtype,
            *args,
            **kwargs,
        )
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank
        self.lora_A = bnb.nn.LinearNF4(
            input_features=in_features,
            output_features=self.rank,
            bias=False,
            device=device,
            compute_dtype=dtype
        )
        self.lora_B = bnb.nn.LinearNF4(
            input_features=self.rank,
            output_features=out_features,
            bias=self.lora_bias,
            device=device,
            compute_dtype=dtype
        )
    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scalar value must be a float"
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = super().forward(x)
        _lora_out_B = self.lora_B(self.lora_A(x))
        lora_update = _lora_out_B * self.scale
        return base_out + lora_update

class LinearLora(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias is not None,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank
        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )
    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scalar value must be a float"
        self.scale = scale
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        base_out = super().forward(input)
        _lora_out_B = self.lora_B(self.lora_A(input))
        lora_update = _lora_out_B * self.scale
        return base_out + lora_update
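One detail of the deleted module worth noting: per its own comment, replace_linear_with_lora widened square layers (in_features == out_features) so the patched layer could also accept the control-image channels, while non-square layers such as final_layer kept their original width. A standalone sketch of that conditional with illustrative layer sizes (not taken from the Flux model definition):

from torch import nn

# Square layers get a doubled in_features to accommodate the control image...
square = nn.Linear(3072, 3072)
widened_in = square.out_features * 2 if square.in_features == square.out_features else square.in_features
print(widened_in)  # 6144

# ...while non-square layers keep their original in_features.
non_square = nn.Linear(64, 3072)
kept_in = non_square.out_features * 2 if non_square.in_features == non_square.out_features else non_square.in_features
print(kept_in)  # 64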
@@ -5,6 +5,9 @@ from typing import Any, Dict
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer


# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
@@ -26,14 +29,14 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
    )

def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
    # converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in converted_state_dict.items():
    for key, value in state_dict.items():
        key_props = key.split(".")
        # Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
        # Leaving this in since it doesn't hurt anything and may be better
        layer_prop_size = -2 if any(prop in key for prop in ["lora_down", "lora_up"]) else -1
        layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
        layer_name = ".".join(key_props[:layer_prop_size])
        param_name = ".".join(key_props[layer_prop_size:])
        if layer_name not in grouped_state_dict:
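A worked example of the key grouping above (the key below is a hypothetical FLUX Control LoRA key, not taken from a real checkpoint): lora_A/lora_B keys keep their last two segments as the parameter name, everything else keeps only the last segment.

key = "double_blocks.0.img_attn.qkv.lora_A.weight"
key_props = key.split(".")
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
layer_name = ".".join(key_props[:layer_prop_size])   # "double_blocks.0.img_attn.qkv"
param_name = ".".join(key_props[layer_prop_size:])   # "lora_A.weight"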
@@ -44,21 +47,19 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        # Convert to a full layer diff
        layers[layer_key] = any_lora_layer_from_state_dict(state_dict=layer_state_dict)

        prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
        if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
            layers[prefixed_key] = LoRALayer(
                layer_state_dict["lora_B.weight"],
                None,
                layer_state_dict["lora_A.weight"],
                None,
                layer_state_dict["lora_B.bias"]
            )
        elif "scale" in layer_state_dict:
            layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
        else:
            raise AssertionError(f"{layer_key} not expected")
    # Create and return the LoRAModelRaw.
    return LoRAModelRaw(layers=layers)


def _convert_lora_bfl_control(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    sd_out: dict[str, torch.Tensor] = {}
    for k in state_dict:
        if k.endswith(".scale"):  # TODO: Fix these patches
            continue
        k_to = k.replace(".lora_B.bias", ".lora_B.diff_b")\
            .replace(".lora_A.weight", ".lora_A.diff")\
            .replace(".lora_B.weight", ".lora_B.diff")
        sd_out[k_to] = state_dict[k]

    # sd_out["img_in.reshape_weight"] = torch.tensor([state_dict["img_in.lora_B.weight"].shape[0], state_dict["img_in.lora_A.weight"].shape[1]])
    return sd_out
@@ -7,6 +7,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_weight_layer import SetWeightLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer

AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetWeightLayer]
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]
@@ -6,30 +6,22 @@ from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size


class SetWeightLayer(LoRALayerBase):
class ReshapeWeightLayer(LoRALayerBase):
    # TODO: Just everything in this class
    def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
        super().__init__(alpha=None, bias=bias)
        self.weight = torch.nn.Parameter(weight) if weight is not None else None
        self.bias = torch.nn.Parameter(bias) if bias is not None else None
        self.manual_scale = scale

    def scale(self):
        return self.manual_scale.float() if self.manual_scale is not None else super().scale()

    @classmethod
    def from_state_dict_values(
        cls,
        values: Dict[str, torch.Tensor],
    ):
        layer = cls(weight=values.get("set_weight", None), bias=values.get("set_bias", None), scale=values.get("set_scale", None))
        cls.warn_on_unhandled_keys(values=values, handled_keys={"set_weight", "set_bias", "set_scale"})
        return layer

    def rank(self) -> int | None:
        return None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight
        return orig_weight

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
invokeai/backend/lora/layers/set_parameter_layer.py (new file)
@@ -0,0 +1,29 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size


class SetParameterLayer(LoRALayerBase):
    def __init__(self, param_name: str, weight: torch.Tensor):
        super().__init__(None, None)
        self.weight = weight
        self.param_name = param_name

    def rank(self) -> int | None:
        return None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight - orig_weight

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        self.weight = self.weight.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return super().calc_size() + calc_tensor_size(self.weight)
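The get_weight contract of SetParameterLayer amounts to returning the difference that, once added to the original parameter by the patcher, sets it to the stored value. A minimal sketch with made-up tensors (not from the repository's tests):

import torch

orig = torch.tensor([1.0, 0.5])    # illustrative original parameter (e.g. a "scale")
target = torch.tensor([2.0, 2.0])  # value the layer stores

diff = target - orig                     # what get_weight(orig) returns
assert torch.equal(orig + diff, target)  # applying the patch sets the parameter to the target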
@@ -9,7 +9,7 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_weight_layer import SetWeightLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer


def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:

@@ -30,7 +30,5 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLo
        return IA3Layer.from_state_dict_values(state_dict)
    elif "w_norm" in state_dict:
        return NormLayer.from_state_dict_values(state_dict)
    elif any(key in state_dict for key in ["set_weight", "set_bias", "set_scale"]):
        return SetWeightLayer.from_state_dict_values(state_dict)
    else:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
@@ -6,11 +6,13 @@ import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
    ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora.lora_full_linear_sidecar_layer import LoRAFullLinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

@@ -93,8 +95,9 @@ class LoRAPatcher:

            # All of the LoRA weight calculations will be done on the same device as the module weight.
            # (Performance will be best if this is a CUDA device.)
            device = module.weight.device
            dtype = module.weight.dtype
            first_param = next(module.parameters())
            device = first_param.device
            dtype = first_param.dtype

            layer_scale = layer.scale()
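The switch from module.weight.device/dtype to the first parameter is presumably meant to also cover modules whose weight attribute is not a plain tensor (for example quantized layers); a small sketch of the lookup, assuming any module that exposes parameters:

import torch

module = torch.nn.Linear(4, 4)
first_param = next(module.parameters())      # works for any module with parameters
device, dtype = first_param.device, first_param.dtype
print(device, dtype)  # e.g. cpu torch.float32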
@@ -114,7 +117,13 @@
            original_weights.save(param_key, module_param)

            if module_param.shape != lora_param_weight.shape:
                lora_param_weight = lora_param_weight.reshape(module_param.shape)
                if module_param.nelement() == lora_param_weight.nelement():
                    lora_param_weight = lora_param_weight.reshape(module_param.shape)
                else:
                    expanded_weight = torch.zeros_like(lora_param_weight, device=module_param.device)
                    slices = tuple(slice(0, dim) for dim in module_param.shape)
                    expanded_weight[slices] = module_param
                    setattr(module, param_name, expanded_weight)
            lora_param_weight *= patch_weight * layer_scale
            module_param += lora_param_weight.to(dtype=dtype)
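The else branch above handles the case where the LoRA patch is larger than the original parameter (the control LoRA doubles the in_features of some layers): the original weight is copied into the corner of a zero tensor with the patch's shape before the update is added. A standalone sketch with illustrative shapes:

import torch

orig_weight = torch.ones(4, 8)    # illustrative original weight
lora_patch = torch.zeros(4, 16)   # illustrative larger patch

expanded_weight = torch.zeros_like(lora_patch)
slices = tuple(slice(0, dim) for dim in orig_weight.shape)
expanded_weight[slices] = orig_weight   # original values land in the top-left block
print(expanded_weight.shape)            # torch.Size([4, 16])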
@@ -244,6 +253,8 @@
                return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
            elif isinstance(lora_layer, ConcatenatedLoRALayer):
                return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
            elif isinstance(lora_layer, FullLayer):
                return LoRAFullLinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
            else:
                raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
        else:
File diff suppressed because it is too large
@@ -0,0 +1,70 @@
import pytest
import torch

from invokeai.backend.lora.conversions.flux_control_lora_utils import (
    is_state_dict_likely_flux_control,
    lora_model_from_flux_control_state_dict,
)
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_control_lora_format import (
    state_dict_keys as flux_control_lora_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, list[int]]):
    """Test that is_state_dict_likely_flux_control() can identify a state dict in the FLUX Control LoRA format."""
    # Construct a state dict that is in the FLUX Control LoRA format.
    state_dict = keys_to_mock_state_dict(sd_keys)

    assert is_state_dict_likely_flux_control(state_dict)


@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
    """Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers
    FLUX LoRA format.
    """
    # Construct a state dict that is not in the FLUX Control LoRA format.
    state_dict = keys_to_mock_state_dict(sd_keys)

    assert not is_state_dict_likely_flux_control(state_dict)


@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_lora_model_from_flux_control_state_dict(sd_keys: dict[str, list[int]]):
    """Test that lora_model_from_flux_control_state_dict() can load a state dict in the FLUX Control LoRA format."""
    # Construct a state dict that is in the FLUX Control LoRA format.
    state_dict = keys_to_mock_state_dict(sd_keys)
    # Load the state dict into a LoRAModelRaw object.
    model = lora_model_from_flux_control_state_dict(state_dict)

    # Check that the model has the correct number of LoRA layers.
    expected_lora_layers: set[str] = set()
    for k in sd_keys:
        k = k.replace("lora_A.weight", "")
        k = k.replace("lora_B.weight", "")
        k = k.replace("lora_B.bias", "")
        k = k.replace(".scale", "")
        expected_lora_layers.add(k)
    # Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
    # the Q weights so that we count these layers once).
    assert len(model.layers) == len(expected_lora_layers)
    assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())


def test_lora_model_from_flux_control_state_dict_extra_keys_error():
    """Test that lora_model_from_flux_control_state_dict() raises an error if the input state_dict contains unexpected
    keys that we don't handle.
    """
    # Construct a state dict that is in the FLUX Control LoRA format.
    state_dict = keys_to_mock_state_dict(flux_control_lora_state_dict_keys)
    # Add an unexpected key.
    state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)

    # Check that an error is raised.
    with pytest.raises(AssertionError):
        lora_model_from_flux_control_state_dict(state_dict)