mirror of https://github.com/invoke-ai/InvokeAI.git

Commit 049ce1826c (parent 2ff4dae5ce)
WIP - adding LoRA sidecar layers
@@ -3,7 +3,10 @@ from typing import Dict, Iterable, Optional, Tuple

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

@@ -110,6 +113,113 @@ class LoRAPatcher:
            layer.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    @torch.no_grad()
    @contextmanager
    def apply_lora_sidecar_patches(
        model: torch.nn.Module,
        patches: Iterable[Tuple[LoRAModelRaw, float]],
        prefix: str,
    ):
        original_modules: dict[str, torch.nn.Module] = {}
        try:
            for patch, patch_weight in patches:
                LoRAPatcher._apply_lora_sidecar_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,
                    patch_weight=patch_weight,
                    original_modules=original_modules,
                )

            yield
        finally:
            # Restore original modules.
            # Note: This logic assumes no nested modules in original_modules.
            for module_key, orig_module in original_modules.items():
                module_parent_key, module_name = module_key.rsplit(".", 1)
                parent_module = model.get_submodule(module_parent_key)
                LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
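
A rough usage sketch of the context manager above; the model variable, the loaded LoRA, and the key prefix are hypothetical placeholders, not part of this commit:

    unet: torch.nn.Module = ...        # hypothetical: the model to patch
    lora_model: LoRAModelRaw = ...     # hypothetical: a LoRA loaded elsewhere

    with LoRAPatcher.apply_lora_sidecar_patches(
        model=unet,
        patches=[(lora_model, 0.75)],  # (LoRA model, patch weight) pairs
        prefix="lora_unet_",           # hypothetical key prefix
    ):
        # While the context is active, patched submodules are wrapped in LoRASidecarModule,
        # so forward passes include the LoRA contribution without mutating the original weights.
        ...
    # On exit, the original submodules are restored in place of the wrappers.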

    @staticmethod
    def _apply_lora_sidecar_patch(
        model: torch.nn.Module,
        patch: LoRAModelRaw,
        patch_weight: float,
        prefix: str,
        original_modules: dict[str, torch.nn.Module],
    ):
        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LoRAPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            # Initialize the LoRA sidecar layer.
            lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)

            # TODO(ryand): Should we move the LoRA sidecar layer to the same device/dtype as the orig module?

            if module_key in original_modules:
                # The module has already been patched with a LoRASidecarModule. Append to it.
                assert isinstance(module, LoRASidecarModule)
                module.add_lora_layer(lora_sidecar_layer)
            else:
                # The module has not yet been patched with a LoRASidecarModule. Create one.
                lora_sidecar_module = LoRASidecarModule(module, [lora_sidecar_layer])
                original_modules[module_key] = module
                module_parent_key, module_name = module_key.rsplit(".", 1)
                module_parent = model.get_submodule(module_parent_key)
                LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
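
To make the flattened vs. non-flattened key distinction from the comment above concrete, a small hypothetical example (the key name is illustrative, not taken from a real LoRA file):

    # Non-flattened key: dots mark real submodule boundaries and can be passed to get_submodule().
    non_flattened_key = "down_blocks.0.attentions.1.proj_in"
    # Flattened (legacy) key: every '.' has been replaced with '_', so the submodule
    # boundaries have to be rediscovered by searching the model.
    flattened_key = non_flattened_key.replace(".", "_")  # "down_blocks_0_attentions_1_proj_in"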

    @staticmethod
    def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
        if isinstance(orig_layer, torch.nn.Linear):
            if isinstance(lora_layer, LoRALayer):
                return LoRALinearSidecarLayer(...)
            else:
                raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
        elif isinstance(orig_layer, torch.nn.Conv1d):
            if isinstance(lora_layer, LoRALayer):
                return LoRAConv1dSidecarLayer(...)
            else:
                raise ValueError(f"Unsupported Conv1D LoRA layer type: {type(lora_layer)}")
        elif isinstance(orig_layer, torch.nn.Conv2d):
            if isinstance(lora_layer, LoRALayer):
                return LoRAConv2dSidecarLayer(...)
            else:
                raise ValueError(f"Unsupported Conv2D LoRA layer type: {type(lora_layer)}")
        elif isinstance(orig_layer, torch.nn.Conv3d):
            if isinstance(lora_layer, LoRALayer):
                return LoRAConv3dSidecarLayer(...)
            else:
                raise ValueError(f"Unsupported Conv3D LoRA layer type: {type(lora_layer)}")
        else:
            raise ValueError(f"Unsupported layer type: {type(orig_layer)}")

    @staticmethod
    def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
        try:
            submodule_index = int(module_name)
            # If the module name is an integer, then we use the __setitem__ method to set the submodule.
            parent_module[submodule_index] = submodule
        except ValueError:
            # If the module name is not an integer, then we use the setattr method to set the submodule.
            setattr(parent_module, module_name, submodule)
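
The integer branch in _set_submodule above exists because some children live in index-addressed containers (e.g. torch.nn.Sequential or torch.nn.ModuleList), where the child's name is a numeric string rather than an attribute name. A self-contained illustration with a hypothetical container:

    import torch

    blocks = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    # "0" parses as an integer, so the child is replaced via __setitem__:
    blocks[0] = torch.nn.Linear(4, 4)
    # A non-numeric name falls back to setattr:
    setattr(blocks, "extra", torch.nn.Identity())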

    @staticmethod
    def _get_submodule(
        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool

invokeai/backend/lora/sidecar_layers/__init__.py (new file, empty)
@@ -0,0 +1,151 @@

import typing

import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer


class LoRAConvSidecarLayer(torch.nn.Module):
    """An implementation of a conv LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
    (https://arxiv.org/pdf/2106.09685.pdf)
    """

    @property
    def conv_module(self) -> type[torch.nn.Conv1d | torch.nn.Conv2d | torch.nn.Conv3d]:
        """The conv module to be set by child classes. One of torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d."""
        raise NotImplementedError(
            "LoRAConvSidecarLayer cannot be used directly. Use LoRAConv1dSidecarLayer, LoRAConv2dSidecarLayer, or "
            "LoRAConv3dSidecarLayer instead."
        )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        include_mid: bool,
        rank: int,
        alpha: float,
        weight: float,
        kernel_size: typing.Union[int, tuple[int]] = 1,
        stride: typing.Union[int, tuple[int]] = 1,
        padding: typing.Union[str, int, tuple[int]] = 0,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize a LoRAConvSidecarLayer.

        Args:
            in_channels (int): The number of channels expected on inputs to this layer.
            out_channels (int): The number of channels on outputs from this layer.
            kernel_size: The kernel_size of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
            stride: The stride of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
            padding: The padding of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
            rank (int): The internal rank of the layer. See the paper for details.
            alpha (float): A scaling factor that enables tuning the rank without having to adjust the learning
                rate. The recommendation from the paper is to set alpha equal to the first rank that you try and
                then do not tune it further. See the paper for more details.
            device (torch.device, optional): Device where weights will be initialized.
            dtype (torch.dtype, optional): Weight dtype.

        Raises:
            ValueError: If the rank is greater than either in_channels or out_channels.
        """
        super().__init__()

        if rank > min(in_channels, out_channels):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}")

        self._down = self.conv_module(
            in_channels,
            rank,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)
        self._mid = None
        if include_mid:
            self._mid = self.conv_module(rank, rank, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)

        # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
        self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))

        self._weight = weight
        self._rank = rank

    @classmethod
    def from_layers(cls, orig_layer: torch.nn.Module, lora_layer: LoRALayer, weight: float):
        # Initialize the LoRA layer.
        with torch.device("meta"):
            model = cls.from_orig_layer(
                orig_layer,
                include_mid=lora_layer.mid is not None,
                rank=lora_layer.rank,
                # TODO(ryand): Is this the right default in case of missing alpha?
                alpha=lora_layer.alpha if lora_layer.alpha is not None else lora_layer.rank,
                weight=weight,
            )

        # Inject weight into the LoRA layer.
        model._up.weight.data = lora_layer.up
        model._down.weight.data = lora_layer.down
        if lora_layer.mid is not None:
            assert model._mid is not None
            model._mid.weight.data = lora_layer.mid

        return model
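
The meta-device pattern in from_layers above constructs the sidecar layer without allocating real storage, then swaps in the LoRA tensors. A minimal sketch of that pattern in isolation, with an illustrative probe layer (torch is already imported at the top of this file):

    with torch.device("meta"):
        probe = torch.nn.Linear(8, 4, bias=False)  # parameters allocated on the meta device, no real storage
    assert probe.weight.is_meta

    probe.weight.data = torch.zeros(4, 8)          # inject concrete weights, as from_layers does
    assert not probe.weight.is_meta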

    @classmethod
    def from_orig_layer(
        cls,
        layer: torch.nn.Module,
        include_mid: bool,
        rank: int,
        alpha: float,
        weight: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        if not isinstance(layer, cls.conv_module):
            raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")

        return cls(
            in_channels=layer.in_channels,
            out_channels=layer.out_channels,
            include_mid=include_mid,
            weight=weight,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            rank=rank,
            alpha=alpha,
            device=layer.weight.device if device is None else device,
            dtype=layer.weight.dtype if dtype is None else dtype,
        )

    def forward(self, x: torch.Tensor):
        x = self._down(x)
        if self._mid is not None:
            x = self._mid(x)
        x = self._up(x)

        x *= self._weight * self.alpha / self._rank
        return x


class LoRAConv1dSidecarLayer(LoRAConvSidecarLayer):
    @property
    def conv_module(self):
        return torch.nn.Conv1d


class LoRAConv2dSidecarLayer(LoRAConvSidecarLayer):
    @property
    def conv_module(self):
        return torch.nn.Conv2d


class LoRAConv3dSidecarLayer(LoRAConvSidecarLayer):
    @property
    def conv_module(self):
        return torch.nn.Conv3d
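
A hedged sketch of exercising one of the concrete subclasses above on its own; the channel counts, rank, alpha, and weight are illustrative, not taken from this commit:

    layer = LoRAConv2dSidecarLayer(
        in_channels=64, out_channels=64, include_mid=False,
        rank=8, alpha=8.0, weight=1.0, kernel_size=3, padding=1,
    )
    x = torch.randn(1, 64, 32, 32)
    delta = layer(x)                        # the low-rank residual, scaled by weight * alpha / rank
    assert delta.shape == (1, 64, 32, 32)   # same shape as the mirrored Conv2d's output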
@@ -0,0 +1,57 @@

import torch


class LoRALinearSidecarLayer(torch.nn.Module):
    """An implementation of a linear LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language
    Models'. (https://arxiv.org/pdf/2106.09685.pdf)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        include_mid: bool,
        rank: int,
        alpha: float,
        weight: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize a LoRALinearSidecarLayer.

        Args:
            in_features (int): Inputs to this layer will be expected to have shape (..., in_features).
            out_features (int): This layer will produce outputs with shape (..., out_features).
            rank (int): The internal rank of the layer. See the paper for details.
            alpha (float): A scaling factor that enables tuning the rank without having to adjust the learning
                rate. The recommendation from the paper is to set alpha equal to the first rank that you try and
                then do not tune it further. See the paper for more details.
            device (torch.device, optional): Device where weights will be initialized.
            dtype (torch.dtype, optional): Weight dtype.

        Raises:
            ValueError: If the rank is greater than either in_features or out_features.
        """
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self._down = torch.nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self._up = torch.nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        self._mid = None
        if include_mid:
            self._mid = torch.nn.Linear(rank, rank, bias=False, device=device, dtype=dtype)

        # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
        self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))

        self._weight = weight
        self._rank = rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._down(x)
        if self._mid is not None:
            x = self._mid(x)
        x = self._up(x)

        x *= self._weight * self.alpha / self._rank
        return x

invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py (new file, 17 lines)
@@ -0,0 +1,17 @@

import torch


class LoRASidecarModule(torch.nn.Module):
    def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
        super().__init__()
        self._orig_module = orig_module
        # Keep the LoRA layers in a ModuleList so they are registered as submodules
        # (they then follow .to(...) calls and appear in the state_dict).
        self._lora_layers = torch.nn.ModuleList(lora_layers)

    def add_lora_layer(self, lora_layer: torch.nn.Module):
        self._lora_layers.append(lora_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each LoRA residual is computed from the original input and added to the wrapped
        # module's output, leaving the original weights untouched.
        y = self._orig_module(x)
        for lora_layer in self._lora_layers:
            y = y + lora_layer(x)
        return y
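
A minimal sketch tying the pieces together, assuming LoRALinearSidecarLayer from the earlier hunk is importable alongside this module; the dimensions, rank, and weight are illustrative:

    orig = torch.nn.Linear(16, 16)
    lora = LoRALinearSidecarLayer(
        in_features=16, out_features=16, include_mid=False,
        rank=4, alpha=4.0, weight=1.0,
    )
    wrapped = LoRASidecarModule(orig, [lora])

    x = torch.randn(2, 16)
    y = wrapped(x)             # orig(x) plus the scaled low-rank residual computed from x
    assert y.shape == (2, 16)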