Lots of updates centered on using the LoRA patcher rather than changing the modules in the transformer model

Brandon Rising 2024-12-11 14:14:50 -05:00 committed by Kent Keirsey
parent 5a035dd19f
commit f53da60b84
11 changed files with 1237 additions and 189 deletions
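
The theme of the change: instead of rewriting the transformer's nn.Linear modules in place (the replace_linear_with_lora path removed below), the FLUX Control LoRA is now loaded as ordinary LoRA layers and applied through the existing LoRA patcher. A minimal, self-contained sketch of the patcher-style approach, using toy helpers rather than InvokeAI APIs:

import torch
from torch import nn

# Toy model and helpers for illustration only; these are not InvokeAI APIs.
model = nn.Sequential(nn.Linear(8, 8))
originals: dict[int, torch.Tensor] = {}

def patch_linear_in_place(linear: nn.Linear, lora_A: torch.Tensor, lora_B: torch.Tensor, scale: float) -> None:
    # Patcher-style approach: add the low-rank update B @ A into the existing
    # weight and remember the original so it can be restored after generation.
    originals[id(linear)] = linear.weight.detach().clone()
    with torch.no_grad():
        linear.weight += (lora_B @ lora_A) * scale

def unpatch_linear(linear: nn.Linear) -> None:
    # Restore the weight saved before patching; the module itself never changed.
    with torch.no_grad():
        linear.weight.copy_(originals[id(linear)])

rank = 4
lora_A, lora_B = torch.randn(rank, 8), torch.randn(8, rank)
patch_linear_in_place(model[0], lora_A, lora_B, scale=1.0)
_ = model(torch.randn(1, 8))  # runs with the LoRA weights applied
unpatch_linear(model[0])      # no module replacement, so nothing to undo structurally

The old approach replaced each nn.Linear with a LoRA-wrapped subclass, which (as the TODO in the removed module notes) had to be undone after every generation run.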

View File

@@ -9,7 +9,6 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.modules.lora import replace_linear_with_lora
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
@@ -305,8 +304,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(transformer, Flux)
config = transformer_info.config
assert config is not None
if self.transformer.structural_lora:
replace_linear_with_lora(transformer, 128)
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.

View File

@@ -20,7 +20,6 @@ from invokeai.backend.flux.modules.layers import (
SingleStreamBlock,
timestep_embedding,
)
from invokeai.backend.flux.modules.lora import replace_linear_with_lora
@dataclass

View File

@@ -1,148 +0,0 @@
import torch
import bitsandbytes as bnb
from torch import nn
def replace_linear_with_lora(
module: nn.Module,
max_rank: int,
scale: float = 1.0,
) -> None:
for name, child in module.named_children():
if isinstance(child, (LinearLora, BNBNF4LinearLora)):
# TODO: We really need to undo this after generation runs
return
if isinstance(child, nn.Linear):
dtype = child.weight.dtype
loraModule = LinearLora
if hasattr(child, "compute_dtype"):
dtype = getattr(child, "compute_dtype")
loraModule = BNBNF4LinearLora
new_lora = loraModule(
# Double the in features to accommodate the control image. This conditional is to avoid increasing the final_layer in features
in_features=child.out_features*2 if child.in_features == child.out_features else child.in_features,
out_features=child.out_features,
bias=child.bias is not None,
rank=max_rank,
scale=scale,
dtype=dtype,
device=child.weight.device,
)
new_lora.weight = child.weight
new_lora.bias = child.bias if child.bias is not None else None
setattr(module, name, new_lora)
else:
replace_linear_with_lora(
module=child,
max_rank=max_rank,
scale=scale,
)
class BNBNF4LinearLora(bnb.nn.LinearNF4):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
rank: int,
dtype: torch.dtype,
device: torch.device,
lora_bias: bool = True,
scale: float = 1.0,
*args,
**kwargs,
) -> None:
super().__init__(
input_features=in_features,
output_features=out_features,
bias=bias is not None,
device=device,
compute_dtype=dtype,
*args,
**kwargs,
)
assert isinstance(scale, float), "scale must be a float"
self.scale = scale
self.rank = rank
self.lora_bias = lora_bias
self.dtype = dtype
self.device = device
if rank > (new_rank := min(self.out_features, self.in_features)):
self.rank = new_rank
self.lora_A = bnb.nn.LinearNF4(
input_features=in_features,
output_features=self.rank,
bias=False,
device=device,
compute_dtype=dtype
)
self.lora_B = bnb.nn.LinearNF4(
input_features=self.rank,
output_features=out_features,
bias=self.lora_bias,
device=device,
compute_dtype=dtype
)
def set_scale(self, scale: float) -> None:
assert isinstance(scale, float), "scalar value must be a float"
self.scale = scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
base_out = super().forward(x)
_lora_out_B = self.lora_B(self.lora_A(x))
lora_update = _lora_out_B * self.scale
return base_out + lora_update
class LinearLora(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
rank: int,
dtype: torch.dtype,
device: torch.device,
lora_bias: bool = True,
scale: float = 1.0,
*args,
**kwargs,
) -> None:
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias is not None,
device=device,
dtype=dtype,
*args,
**kwargs,
)
assert isinstance(scale, float), "scale must be a float"
self.scale = scale
self.rank = rank
self.lora_bias = lora_bias
self.dtype = dtype
self.device = device
if rank > (new_rank := min(self.out_features, self.in_features)):
self.rank = new_rank
self.lora_A = nn.Linear(
in_features=in_features,
out_features=self.rank,
bias=False,
dtype=dtype,
device=device,
)
self.lora_B = nn.Linear(
in_features=self.rank,
out_features=out_features,
bias=self.lora_bias,
dtype=dtype,
device=device,
)
def set_scale(self, scale: float) -> None:
assert isinstance(scale, float), "scalar value must be a float"
self.scale = scale
def forward(self, input: torch.Tensor) -> torch.Tensor:
base_out = super().forward(input)
_lora_out_B = self.lora_B(self.lora_A(input))
lora_update = _lora_out_B * self.scale
return base_out + lora_update

View File

@@ -5,6 +5,9 @@ from typing import Any, Dict
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
@@ -26,14 +29,14 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in converted_state_dict.items():
for key, value in state_dict.items():
key_props = key.split(".")
# Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
# Leaving this in since it doesn't hurt anything and may be better
layer_prop_size = -2 if any(prop in key for prop in ["lora_down", "lora_up"]) else -1
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
layer_name = ".".join(key_props[:layer_prop_size])
param_name = ".".join(key_props[layer_prop_size:])
if layer_name not in grouped_state_dict:
@@ -44,21 +47,19 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
# Convert to a full layer diff
layers[layer_key] = any_lora_layer_from_state_dict(state_dict=layer_state_dict)
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
layers[prefixed_key] = LoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"]
)
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
else:
raise AssertionError(f"{layer_key} not expected")
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
def _convert_lora_bfl_control(state_dict: dict[str, torch.Tensor])-> dict[str, torch.Tensor]:
sd_out: dict[str, torch.Tensor] = {}
for k in state_dict:
if k.endswith(".scale"): # TODO: Fix these patches
continue
k_to = k.replace(".lora_B.bias", ".lora_B.diff_b")\
.replace(".lora_A.weight", ".lora_A.diff")\
.replace(".lora_B.weight", ".lora_B.diff")
sd_out[k_to] = state_dict[k]
# sd_out["img_in.reshape_weight"] = torch.tensor([state_dict["img_in.lora_B.weight"].shape[0], state_dict["img_in.lora_A.weight"].shape[1]])
return sd_out
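
To make the key-grouping logic in this hunk concrete, here is a small standalone example (the keys are illustrative; real FLUX Control LoRA state dicts may use different layer names). Keys carrying a lora_A/lora_B suffix keep a two-part parameter name, while plain parameters such as scale keep a one-part name:

# Illustrative keys only; layer names are made up for the example.
keys = [
    "double_blocks.0.img_attn.qkv.lora_A.weight",
    "double_blocks.0.img_attn.qkv.lora_B.weight",
    "double_blocks.0.img_attn.qkv.lora_B.bias",
    "double_blocks.0.img_attn.norm.key_norm.scale",
]

grouped: dict[str, list[str]] = {}
for key in keys:
    key_props = key.split(".")
    # Two-part suffix for lora_A/lora_B params, one-part suffix otherwise.
    layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
    layer_name = ".".join(key_props[:layer_prop_size])
    param_name = ".".join(key_props[layer_prop_size:])
    grouped.setdefault(layer_name, []).append(param_name)

print(grouped)
# {'double_blocks.0.img_attn.qkv': ['lora_A.weight', 'lora_B.weight', 'lora_B.bias'],
#  'double_blocks.0.img_attn.norm.key_norm': ['scale']}

Under this grouping, the qkv group matches the lora_A.weight / lora_B.weight / lora_B.bias case and becomes a LoRALayer, while the bare scale group becomes a SetParameterLayer.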

View File

@@ -7,6 +7,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_weight_layer import SetWeightLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetWeightLayer]
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]

View File

@@ -6,30 +6,22 @@ from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class SetWeightLayer(LoRALayerBase):
class ReshapeWeightLayer(LoRALayerBase):
# TODO: Just everything in this class
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight) if weight is not None else None
self.bias = torch.nn.Parameter(bias) if bias is not None else None
self.manual_scale = scale
def scale(self):
return self.manual_scale.float() if self.manual_scale is not None else super().scale()
@classmethod
def from_state_dict_values(
cls,
values: Dict[str, torch.Tensor],
):
layer = cls(weight=values.get("set_weight", None), bias=values.get("set_bias", None), scale=values.get("set_scale", None))
cls.warn_on_unhandled_keys(values=values, handled_keys={"set_weight", "set_bias", "set_scale"})
return layer
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
return orig_weight
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)

View File

@@ -0,0 +1,29 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class SetParameterLayer(LoRALayerBase):
def __init__(self, param_name: str, weight: torch.Tensor):
super().__init__(None, None)
self.weight = weight
self.param_name = param_name
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight - orig_weight
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.weight)
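
A short sketch of why get_weight above returns a difference rather than the raw value: the patcher adds a layer's output to the existing parameter, so storing target minus original makes the patched parameter land exactly on the stored value. Plain tensors are used here for illustration, not the InvokeAI classes:

import torch

original = torch.ones(4)           # parameter currently in the module
target = torch.full((4,), 2.5)     # value the control LoRA wants to set

diff = target - original           # analogous to get_weight(orig_weight)
patched = original + diff          # what the additive patching produces

assert torch.allclose(patched, target)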

View File

@@ -9,7 +9,7 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_weight_layer import SetWeightLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
@@ -30,7 +30,5 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLo
return IA3Layer.from_state_dict_values(state_dict)
elif "w_norm" in state_dict:
return NormLayer.from_state_dict_values(state_dict)
elif any(key in state_dict for key in ["set_weight", "set_bias", "set_scale"]):
return SetWeightLayer.from_state_dict_values(state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

View File

@@ -6,11 +6,13 @@ import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora.lora_full_linear_sidecar_layer import LoRAFullLinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
@@ -93,8 +95,9 @@ class LoRAPatcher:
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
first_param = next(module.parameters())
device = first_param.device
dtype = first_param.dtype
layer_scale = layer.scale()
@@ -114,7 +117,13 @@
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
if module_param.nelement() == lora_param_weight.nelement():
lora_param_weight = lora_param_weight.reshape(module_param.shape)
else:
expanded_weight = torch.zeros_like(lora_param_weight, device=module_param.device)
slices = tuple(slice(0, dim) for dim in module_param.shape)
expanded_weight[slices] = module_param
setattr(module, param_name, expanded_weight)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
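
A standalone sketch of the zero-padding branch above, with made-up shapes. When a control LoRA's full parameter is larger than the module's (for example the doubled in_features that accommodates the concatenated control image), the module weight is copied into the leading slice of a zero tensor shaped like the LoRA parameter before the update is added:

import torch

module_weight = torch.randn(8, 8)        # existing module parameter
lora_param_weight = torch.randn(8, 16)   # control-LoRA diff with doubled in_features

# Start from zeros shaped like the LoRA parameter and copy the original weight
# into the leading slice of each dimension; the new input channels start at zero.
expanded_weight = torch.zeros_like(lora_param_weight)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
expanded_weight[slices] = module_weight

patched = expanded_weight + lora_param_weight
assert patched.shape == (8, 16)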
@@ -244,6 +253,8 @@ class LoRAPatcher:
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
elif isinstance(lora_layer, FullLayer):
return LoRAFullLinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
else:
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
else:

File diff suppressed because it is too large

View File

@@ -0,0 +1,70 @@
import pytest
import torch
from invokeai.backend.lora.conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_control_lora_format import (
state_dict_keys as flux_control_lora_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_flux_control() can identify a state dict in the FLUX Control LoRA format."""
# Construct a state dict that is in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
assert is_state_dict_likely_flux_control(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers
FLUX LoRA format.
"""
# Construct a state dict that is not in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_flux_control(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_lora_model_from_flux_control_state_dict(sd_keys: dict[str, list[int]]):
"""Test that lora_model_from_flux_control_state_dict() can load a state dict in the FLUX Control LoRA format."""
# Construct a state dict that is in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
# Load the state dict into a LoRAModelRaw object.
model = lora_model_from_flux_control_state_dict(state_dict)
# Check that the model has the correct number of LoRA layers.
expected_lora_layers: set[str] = set()
for k in sd_keys:
k = k.replace("lora_A.weight", "")
k = k.replace("lora_B.weight", "")
k = k.replace("lora_B.bias", "")
k = k.replace(".scale", "")
expected_lora_layers.add(k)
# Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
# the Q weights so that we count these layers once).
assert len(model.layers) == len(expected_lora_layers)
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
def test_lora_model_from_flux_control_state_dict_extra_keys_error():
"""Test that lora_model_from_flux_control_state_dict() raises an error if the input state_dict contains unexpected
keys that we don't handle.
"""
# Construct a state dict that is in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(flux_control_lora_state_dict_keys)
# Add an unexpected key.
state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)
# Check that an error is raised.
with pytest.raises(AssertionError):
lora_model_from_flux_control_state_dict(state_dict)