Remove unused layer_key property from LoRALayerBase.

Ryan Dick 2024-09-10 15:17:12 +00:00 committed by Kent Keirsey
parent fef26a5f2f
commit 705173b575
13 changed files with 20 additions and 34 deletions
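
The change is mechanical: LoRA layer classes no longer accept a layer_key argument, so call sites pass only the values state dict. A minimal before/after sketch of the new call shape (the key name and tensor shapes below are made-up placeholders, not values from this diff):

import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer

# Placeholder LoRA weights for a hypothetical 320-feature linear layer with rank 8.
values = {
    "lora_down.weight": torch.zeros(8, 320),
    "lora_up.weight": torch.zeros(320, 8),
    "alpha": torch.tensor(8.0),
}

# Before: layer = LoRALayer("example_layer_key", values)
# After:  the key is dropped and the values dict is the only required argument.
layer = LoRALayer(values=values)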

@@ -55,8 +55,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
         if src_key in grouped_state_dict:
             src_layer_dict = grouped_state_dict.pop(src_key)
             layers[dst_key] = LoRALayer(
-                dst_key,
-                {
+                values={
                     "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                     "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
                     "alpha": torch.tensor(alpha),
@@ -81,7 +80,6 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
         for src_layer_dict in src_layer_dicts:
             sub_layers.append(
                 LoRALayer(
-                    layer_key="",
                     values={
                         "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                         "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
@@ -90,7 +88,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
                 )
             )
             assert len(src_layer_dict) == 0
-        layers[dst_qkv_key] = ConcatenatedLoRALayer(layer_key=dst_qkv_key, lora_layers=sub_layers, concat_axis=0)
+        layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)

     # time_text_embed.timestep_embedder -> time_in.
     add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")

@@ -41,8 +41,7 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
     for layer_key, layer_state_dict in grouped_state_dict.items():
-        layer = any_lora_layer_from_state_dict(layer_key, layer_state_dict)
-        layers[layer_key] = layer
+        layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

     # Create and return the LoRAModelRaw.
     return LoRAModelRaw(layers=layers)

@@ -12,8 +12,7 @@ def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAMo
     layers: dict[str, AnyLoRALayer] = {}
     for layer_key, values in grouped_state_dict.items():
-        layer = any_lora_layer_from_state_dict(layer_key, values)
-        layers[layer_key] = layer
+        layers[layer_key] = any_lora_layer_from_state_dict(values)

     return LoRAModelRaw(layers=layers)

@@ -13,9 +13,9 @@ class ConcatenatedLoRALayer(LoRALayerBase):
     stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
     """

-    def __init__(self, layer_key: str, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
+    def __init__(self, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
         # Note: We pass values={} to the base class, because the values are handled by the individual LoRA layers.
-        super().__init__(layer_key, values={})
+        super().__init__(values={})

         self._lora_layers = lora_layers
         self._concat_axis = concat_axis
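
For the FLUX QKV path above, the concatenated layer is now built from plain sub-layers with no key argument. A hedged sketch under assumed imports (the ConcatenatedLoRALayer module path and the helper below are illustrative, and the tensors are zero placeholders rather than real checkpoint weights):

import torch

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer  # assumed module path
from invokeai.backend.lora.layers.lora_layer import LoRALayer


def make_placeholder_sub_layer(out_features: int, in_features: int, rank: int = 4) -> LoRALayer:
    # Stand-in for one of the Q/K/V LoRA layers parsed from a diffusers checkpoint.
    return LoRALayer(
        values={
            "lora_down.weight": torch.zeros(rank, in_features),
            "lora_up.weight": torch.zeros(out_features, rank),
        }
    )


# One sub-layer each for Q, K and V; neither the sub-layers nor the
# concatenated layer receives a layer_key anymore.
sub_layers = [make_placeholder_sub_layer(3072, 3072) for _ in range(3)]
qkv_layer = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)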

@@ -12,10 +12,9 @@ class FullLayer(LoRALayerBase):
     def __init__(
         self,
-        layer_key: str,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(layer_key, values)
+        super().__init__(values)

         self.weight = values["diff"]
         self.bias = values.get("diff_b", None)

@@ -11,10 +11,9 @@ class IA3Layer(LoRALayerBase):
     def __init__(
         self,
-        layer_key: str,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(layer_key, values)
+        super().__init__(values)

         self.weight = values["weight"]
         self.on_input = values["on_input"]

@@ -13,8 +13,8 @@ class LoHALayer(LoRALayerBase):
     # t1: Optional[torch.Tensor] = None
     # t2: Optional[torch.Tensor] = None

-    def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
-        super().__init__(layer_key, values)
+    def __init__(self, values: Dict[str, torch.Tensor]):
+        super().__init__(values)

         self.w1_a = values["hada_w1_a"]
         self.w1_b = values["hada_w1_b"]

@@ -16,10 +16,9 @@ class LoKRLayer(LoRALayerBase):
     def __init__(
         self,
-        layer_key: str,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(layer_key, values)
+        super().__init__(values)

         self.w1 = values.get("lokr_w1", None)
         if self.w1 is None:

@@ -13,10 +13,9 @@ class LoRALayer(LoRALayerBase):
     def __init__(
         self,
-        layer_key: str,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(layer_key, values)
+        super().__init__(values)

         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]

@@ -9,7 +9,6 @@ class LoRALayerBase:
     # rank: Optional[int]
    # alpha: Optional[float]
    # bias: Optional[torch.Tensor]
-    # layer_key: str

     # @property
     # def scale(self):
@@ -17,7 +16,6 @@ class LoRALayerBase:
     def __init__(
         self,
-        layer_key: str,
         values: Dict[str, torch.Tensor],
     ):
         if "alpha" in values:
@@ -36,7 +34,6 @@ class LoRALayerBase:
             self.bias = None

         self.rank = None  # set in layer implementation
-        self.layer_key = layer_key

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()

@@ -12,10 +12,9 @@ class NormLayer(LoRALayerBase):
     def __init__(
         self,
-        layer_key: str,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(layer_key, values)
+        super().__init__(values)

         self.weight = values["w_norm"]
         self.bias = values.get("b_norm", None)

@@ -11,23 +11,23 @@ from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.layers.norm_layer import NormLayer


-def any_lora_layer_from_state_dict(layer_key: str, state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
+def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
     # Detect layers according to LyCORIS detection logic(`weight_list_det`)
     # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
     if "lora_up.weight" in state_dict:
         # LoRA a.k.a LoCon
-        return LoRALayer(layer_key, state_dict)
+        return LoRALayer(state_dict)
     elif "hada_w1_a" in state_dict:
-        return LoHALayer(layer_key, state_dict)
+        return LoHALayer(state_dict)
     elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
-        return LoKRLayer(layer_key, state_dict)
+        return LoKRLayer(state_dict)
     elif "diff" in state_dict:
         # Full a.k.a Diff
-        return FullLayer(layer_key, state_dict)
+        return FullLayer(state_dict)
     elif "on_input" in state_dict:
-        return IA3Layer(layer_key, state_dict)
+        return IA3Layer(state_dict)
     elif "w_norm" in state_dict:
-        return NormLayer(layer_key, state_dict)
+        return NormLayer(state_dict)
     else:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
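
With the factory signature reduced to the state dict alone, the layer key survives only as the dictionary key on the caller's side, as in the Kohya and SD conversion functions above, while layer-type detection runs purely on the state-dict contents. A hedged usage sketch (the module path, key name, and tensors are assumptions for illustration):

import torch

from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict  # assumed module path

grouped_state_dict = {
    "lora_unet_example_block": {  # hypothetical layer key
        "lora_down.weight": torch.zeros(4, 768),
        "lora_up.weight": torch.zeros(768, 4),
        "alpha": torch.tensor(4.0),
    },
}

layers = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
    # The factory no longer receives layer_key; it is kept only as the dict key.
    layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)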

@@ -28,7 +28,6 @@ def test_apply_lora(device: str):
     lora_layers = {
         "linear_layer_1": LoRALayer(
-            layer_key="linear_layer_1",
             values={
                 "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                 "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
@@ -72,7 +71,6 @@ def test_apply_lora_change_device():
     lora_layers = {
         "linear_layer_1": LoRALayer(
-            layer_key="linear_layer_1",
             values={
                 "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                 "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),