mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 04:18:46 +08:00
Remove unused layer_key property from LoRALayerBase.
This commit is contained in:
parent
fef26a5f2f
commit
705173b575
@ -55,8 +55,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
if src_key in grouped_state_dict:
|
||||
src_layer_dict = grouped_state_dict.pop(src_key)
|
||||
layers[dst_key] = LoRALayer(
|
||||
dst_key,
|
||||
{
|
||||
values={
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
"alpha": torch.tensor(alpha),
|
||||
@ -81,7 +80,6 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
for src_layer_dict in src_layer_dicts:
|
||||
sub_layers.append(
|
||||
LoRALayer(
|
||||
layer_key="",
|
||||
values={
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
@ -90,7 +88,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
)
|
||||
)
|
||||
assert len(src_layer_dict) == 0
|
||||
layers[dst_qkv_key] = ConcatenatedLoRALayer(layer_key=dst_qkv_key, lora_layers=sub_layers, concat_axis=0)
|
||||
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
|
||||
|
||||
# time_text_embed.timestep_embedder -> time_in.
|
||||
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
|
||||
|
@ -41,8 +41,7 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, layer_state_dict in grouped_state_dict.items():
|
||||
layer = any_lora_layer_from_state_dict(layer_key, layer_state_dict)
|
||||
layers[layer_key] = layer
|
||||
layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
|
||||
|
||||
# Create and return the LoRAModelRaw.
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
@ -12,8 +12,7 @@ def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAMo
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, values in grouped_state_dict.items():
|
||||
layer = any_lora_layer_from_state_dict(layer_key, values)
|
||||
layers[layer_key] = layer
|
||||
layers[layer_key] = any_lora_layer_from_state_dict(values)
|
||||
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
|
@ -13,9 +13,9 @@ class ConcatenatedLoRALayer(LoRALayerBase):
|
||||
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_key: str, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
|
||||
def __init__(self, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
|
||||
# Note: We pass values={} to the base class, because the values are handled by the individual LoRA layers.
|
||||
super().__init__(layer_key, values={})
|
||||
super().__init__(values={})
|
||||
|
||||
self._lora_layers = lora_layers
|
||||
self._concat_axis = concat_axis
|
||||
|
@ -12,10 +12,9 @@ class FullLayer(LoRALayerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
super().__init__(values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
@ -11,10 +11,9 @@ class IA3Layer(LoRALayerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
super().__init__(values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
@ -13,8 +13,8 @@ class LoHALayer(LoRALayerBase):
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
def __init__(self, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
|
@ -16,10 +16,9 @@ class LoKRLayer(LoRALayerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
super().__init__(values)
|
||||
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
|
@ -13,10 +13,9 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
super().__init__(values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
|
@ -9,7 +9,6 @@ class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
@ -17,7 +16,6 @@ class LoRALayerBase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
@ -36,7 +34,6 @@ class LoRALayerBase:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
@ -12,10 +12,9 @@ class NormLayer(LoRALayerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
super().__init__(values)
|
||||
|
||||
self.weight = values["w_norm"]
|
||||
self.bias = values.get("b_norm", None)
|
||||
|
@ -11,23 +11,23 @@ from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
|
||||
|
||||
def any_lora_layer_from_state_dict(layer_key: str, state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
|
||||
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
if "lora_up.weight" in state_dict:
|
||||
# LoRA a.k.a LoCon
|
||||
return LoRALayer(layer_key, state_dict)
|
||||
return LoRALayer(state_dict)
|
||||
elif "hada_w1_a" in state_dict:
|
||||
return LoHALayer(layer_key, state_dict)
|
||||
return LoHALayer(state_dict)
|
||||
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
|
||||
return LoKRLayer(layer_key, state_dict)
|
||||
return LoKRLayer(state_dict)
|
||||
elif "diff" in state_dict:
|
||||
# Full a.k.a Diff
|
||||
return FullLayer(layer_key, state_dict)
|
||||
return FullLayer(state_dict)
|
||||
elif "on_input" in state_dict:
|
||||
return IA3Layer(layer_key, state_dict)
|
||||
return IA3Layer(state_dict)
|
||||
elif "w_norm" in state_dict:
|
||||
return NormLayer(layer_key, state_dict)
|
||||
return NormLayer(state_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
|
||||
|
@ -28,7 +28,6 @@ def test_apply_lora(device: str):
|
||||
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer(
|
||||
layer_key="linear_layer_1",
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
|
||||
@ -72,7 +71,6 @@ def test_apply_lora_change_device():
|
||||
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer(
|
||||
layer_key="linear_layer_1",
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
|
||||
|
Loading…
Reference in New Issue
Block a user