WIP - adding LoRA sidecar layers

This commit is contained in:
Ryan Dick 2024-09-10 21:45:18 +00:00 committed by Kent Keirsey
parent 2ff4dae5ce
commit 049ce1826c
8 changed files with 335 additions and 0 deletions

View File

@ -3,7 +3,10 @@ from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
@ -110,6 +113,113 @@ class LoRAPatcher:
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
):
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoraPatcher._apply_lora_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_modules=original_modules,
)
yield
finally:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = module_key.rsplit(".", 1)
parent_module = model.get_submodule(module_parent_key)
LoraPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_lora_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
):
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoraPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Initialize the LoRA sidecar layer.
lora_sidecar_layer = LoraPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
# TODO(ryand): Should we move the LoRA sidecar layer to the same device/dtype as the orig module?
if module_key in original_modules:
# The module has already been patched with a LoRASidecarModule. Append to it.
assert isinstance(module, LoRASidecarModule)
module.add_lora_layer(lora_sidecar_layer)
else:
# The module has not yet been patched with a LoRASidecarModule. Create one.
lora_sidecar_module = LoRASidecarModule(module, [lora_sidecar_layer])
original_modules[module_key] = module
module_parent_key, module_name = module_key.rsplit(".", 1)
module_parent = model.get_submodule(module_parent_key)
LoraPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
@staticmethod
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
if isinstance(orig_layer, torch.nn.Linear):
if isinstance(lora_layer, LoRALayer):
return LoRALinearSidecarLayer(...)
else:
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
elif isinstance(orig_layer, torch.nn.Conv1d):
if isinstance(lora_layer, LoRALayer):
return LoRAConv1dSidecarLayer(...)
else:
raise ValueError(f"Unsupported Conv1D LoRA layer type: {type(lora_layer)}")
elif isinstance(orig_layer, torch.nn.Conv2d):
if isinstance(lora_layer, LoRALayer):
return LoRAConv2dSidecarLayer(...)
else:
raise ValueError(f"Unsupported Conv2D LoRA layer type: {type(lora_layer)}")
elif isinstance(orig_layer, torch.nn.Conv3d):
if isinstance(lora_layer, LoRALayer):
return LoRAConv3dSidecarLayer(...)
else:
raise ValueError(f"Unsupported Conv3D LoRA layer type: {type(lora_layer)}")
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
@staticmethod
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
try:
submodule_index = int(module_name)
# If the module name is an integer, then we use the __setitem__ method to set the submodule.
parent_module[submodule_index] = submodule
except ValueError:
# If the module name is not an integer, then we use the setattr method to set the submodule.
setattr(parent_module, module_name, submodule)
@staticmethod
def _get_submodule(
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool

View File

@ -0,0 +1,151 @@
import typing
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRAConvSidecarLayer(torch.nn.Module):
"""An implementation of a conv LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
(https://arxiv.org/pdf/2106.09685.pdf)
"""
@property
def conv_module(self) -> type[torch.nn.Conv1d | torch.nn.Conv2d | torch.nn.Conv3d]:
"""The conv module to be set by child classes. One of torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d."""
raise NotImplementedError(
"LoRAConvLayer cannot be used directly. Use LoRAConv1dLayer, LoRAConv2dLayer, or LoRAConv3dLayer instead."
)
def __init__(
self,
in_channels: int,
out_channels: int,
include_mid: bool,
rank: int,
alpha: float,
weight: float,
kernel_size: typing.Union[int, tuple[int]] = 1,
stride: typing.Union[int, tuple[int]] = 1,
padding: typing.Union[str, int, tuple[int]] = 0,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""Initialize a LoRAConvLayer.
Args:
in_channels (int): The number of channels expected on inputs to this layer.
out_channels (int): The number of channels on outputs from this layer.
kernel_size: The kernel_size of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
stride: The stride of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
padding: The padding of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
rank (int, optional): The internal rank of the layer. See the paper for details.
alpha (float, optional): A scaling factor that enables tuning the rank without having to adjust the learning
rate. The recommendation from the paper is to set alpha equal to the first rank that you try and then do
not tune it further. See the paper for more details.
device (torch.device, optional): Device where weights will be initialized.
dtype (torch.dtype, optional): Weight dtype.
Raises:
ValueError: If the rank is greater than either in_channels or out_channels.
"""
super().__init__()
if rank > min(in_channels, out_channels):
raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}")
self._down = self.conv_module(
in_channels,
rank,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False,
device=device,
dtype=dtype,
)
self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)
self._mid = None
if include_mid:
self._mid = self.conv_module(rank, rank, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)
# Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))
self._weight = weight
self._rank = rank
@classmethod
def from_layers(cls, orig_layer: torch.nn.Module, lora_layer: LoRALayer, weight: float):
# Initialize the LoRA layer.
with torch.device("meta"):
model = cls.from_orig_layer(
orig_layer,
include_mid=lora_layer.mid is not None,
rank=lora_layer.rank,
# TODO(ryand): Is this the right default in case of missing alpha?
alpha=lora_layer.alpha if lora_layer.alpha is not None else lora_layer.rank,
weight=weight,
)
# Inject weight into the LoRA layer.
model._up.weight.data = lora_layer.up
model._down.weight.data = lora_layer.down
if lora_layer.mid is not None:
assert model._mid is not None
model._mid.weight.data = lora_layer.mid
return model
@classmethod
def from_orig_layer(
cls,
layer: torch.nn.Module,
include_mid: bool,
rank: int,
alpha: float,
weight: float,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
if not isinstance(layer, cls.conv_module):
raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")
return cls(
in_channels=layer.in_channels,
out_channels=layer.out_channels,
include_mid=include_mid,
weight=weight,
kernel_size=layer.kernel_size,
stride=layer.stride,
padding=layer.padding,
rank=rank,
alpha=alpha,
device=layer.weight.device if device is None else device,
dtype=layer.weight.dtype if dtype is None else dtype,
)
def forward(self, x: torch.Tensor):
x = self._down(x)
if self._mid is not None:
x = self._mid(x)
x = self._up(x)
x *= self._weight * self.alpha / self._rank
return x
class LoRAConv1dSidecarLayer(LoRAConvSidecarLayer):
@property
def conv_module(self):
return torch.nn.Conv1d
class LoRAConv2dSidecarLayer(LoRAConvSidecarLayer):
@property
def conv_module(self):
return torch.nn.Conv2d
class LoRAConv3dSidecarLayer(LoRAConvSidecarLayer):
@property
def conv_module(self):
return torch.nn.Conv3d

View File

@ -0,0 +1,57 @@
import torch
class LoRALinearSidecarLayer(torch.nn.Module):
"""An implementation of a linear LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
(https://arxiv.org/pdf/2106.09685.pdf)
"""
def __init__(
self,
in_features: int,
out_features: int,
include_mid: bool,
rank: int,
alpha: float,
weight: float,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""Initialize a LoRALinearLayer.
Args:
in_features (int): Inputs to this layer will be expected to have shape (..., in_features).
out_features (int): This layer will produce outputs with shape (..., out_features).
rank (int, optional): The internal rank of the layer. See the paper for details.
alpha (float, optional): A scaling factor that enables tuning the rank without having to adjust the learning
rate. The recommendation from the paper is to set alpha equal to the first rank that you try and then do
not tune it further. See the paper for more details.
device (torch.device, optional): Device where weights will be initialized.
dtype (torch.dtype, optional): Weight dtype.
Raises:
ValueError: If the rank is greater than either in_features or out_features.
"""
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")
self._down = torch.nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self._up = torch.nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
self._mid = None
if include_mid:
self._mid = torch.nn.Linear(rank, rank, bias=False, device=device, dtype=dtype)
# Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))
self._weight = weight
self._rank = rank
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._down(x)
if self._mid is not None:
x = self._mid(x)
x = self._up(x)
x *= self._weight * self.alpha / self._rank
return x

View File

@ -0,0 +1,17 @@
import torch
class LoRASidecarModule(torch.nn.Module):
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
super().__init__()
self._orig_module = orig_module
self._lora_layers = lora_layers
def add_lora_layer(self, lora_layer: torch.nn.Module):
self._lora_layers.append(lora_layer)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._orig_module(x)
for lora_layer in self._lora_layers:
x += lora_layer(x)
return x