Delete old sidecar wrapper implementation. This functionality has moved into the custom layers.

Ryan Dick 2024-12-29 17:33:08 +00:00
parent 52fc5a64d4
commit 6fd9b0a274
9 changed files with 0 additions and 393 deletions

View File

@@ -1,56 +0,0 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch


class BaseSidecarWrapper(torch.nn.Module):
    """A base class for sidecar wrappers.

    A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a list of patches as 'sidecar'
    patches. I.e. it applies the sidecar patches during forward inference without modifying the original module.

    Sidecar wrappers are typically used over regular patches when:
    - The original module is quantized and so the weights can't be patched in the usual way.
    - The original module is on the CPU and modifying the weights would require backing up the original weights and
      doubling the CPU memory usage.
    """

    def __init__(
        self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None
    ):
        super().__init__()
        self._orig_module = orig_module
        self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights

    @property
    def orig_module(self) -> torch.nn.Module:
        return self._orig_module

    def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
        """Add a patch to the sidecar wrapper."""
        self._patches_and_weights.append((patch, patch_weight))

    def _aggregate_patch_parameters(
        self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
    ) -> dict[str, torch.Tensor]:
        """Helper function that aggregates the parameters from all patches into a single dict."""
        params: dict[str, torch.Tensor] = {}

        for patch, patch_weight in patches_and_weights:
            # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
            # parameters, this might fail or return incorrect results.
            layer_params = patch.get_parameters(
                dict(self._orig_module.named_parameters(recurse=False)), weight=patch_weight
            )

            for param_name, param_weight in layer_params.items():
                if param_name not in params:
                    params[param_name] = param_weight
                else:
                    params[param_name] += param_weight

        return params

    def forward(self, *args, **kwargs):  # type: ignore
        raise NotImplementedError()
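For orientation, here is a minimal sketch (not part of this commit) of how a concrete subclass would typically combine orig_module with _aggregate_patch_parameters: the wrapped module runs unmodified, and the aggregated patch parameters are applied as an additive residual. The subclass name ExampleLinearSidecar is hypothetical; it mirrors the fallback path of the LinearSidecarWrapper removed later in this commit.

import torch

from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class ExampleLinearSidecar(BaseSidecarWrapper):  # hypothetical, for illustration only
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Sum the weight/bias residuals contributed by all attached patches.
        residuals = self._aggregate_patch_parameters(self._patches_and_weights)
        # Original (unmodified) output plus the residual computed from the patches.
        return self.orig_module(input) + torch.nn.functional.linear(
            input, residuals["weight"], residuals.get("bias", None)
        )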

View File

@@ -1,11 +0,0 @@
import torch

from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class Conv1dSidecarWrapper(BaseSidecarWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
        return self.orig_module(input) + torch.nn.functional.conv1d(
            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
        )
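This residual formulation works because convolution is linear in its weights and bias: convolving with patched parameters equals the original convolution plus a convolution with only the parameter deltas. A minimal standalone sketch of that equivalence (arbitrary shapes, plain torch):

import torch

x = torch.randn(1, 3, 16)   # (batch, channels, length)
w = torch.randn(8, 3, 5)    # original conv1d weight
b = torch.randn(8)          # original conv1d bias
dw = torch.randn(8, 3, 5)   # aggregated weight residual from the patches
db = torch.randn(8)         # aggregated bias residual from the patches

patched = torch.nn.functional.conv1d(x, w + dw, b + db)
sidecar = torch.nn.functional.conv1d(x, w, b) + torch.nn.functional.conv1d(x, dw, db)
assert torch.allclose(patched, sidecar, atol=1e-4)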

View File

@@ -1,11 +0,0 @@
import torch

from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class Conv2dSidecarWrapper(BaseSidecarWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
        return self.orig_module(input) + torch.nn.functional.conv2d(
            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
        )

View File

@@ -1,24 +0,0 @@
import torch

from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
    """A sidecar wrapper for a FLUX RMSNorm layer.

    This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
    the RMSNorm scale parameters.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Given the narrow focus of this wrapper, we only support a very particular patch configuration:
        assert len(self._patches_and_weights) == 1
        patch, _patch_weight = self._patches_and_weights[0]
        assert isinstance(patch, SetParameterLayer)
        assert patch.param_name == "scale"

        # Apply the patch.
        # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
        # be handled.
        return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)
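Because the patch fully replaces the scale parameter, the wrapper's output is simply the RMS-normalized input scaled by the patch weight. A small standalone sketch (arbitrary dim, plain torch) of what torch.nn.functional.rms_norm computes here:

import torch

dim = 8
x = torch.randn(2, dim)
new_scale = torch.randn(dim)  # stands in for patch.weight

rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
manual = x / rms * new_scale
assert torch.allclose(manual, torch.nn.functional.rms_norm(x, new_scale.shape, new_scale, eps=1e-6), atol=1e-5)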

View File

@@ -1,66 +0,0 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class LinearSidecarWrapper(BaseSidecarWrapper):
    def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
        """An optimized implementation of the residual calculation for a Linear LoRALayer."""
        x = torch.nn.functional.linear(input, lora_layer.down)
        if lora_layer.mid is not None:
            x = torch.nn.functional.linear(x, lora_layer.mid)
        x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
        x *= lora_weight * lora_layer.scale()
        return x

    def _concatenated_lora_forward(
        self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
    ) -> torch.Tensor:
        """An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
        x_chunks: list[torch.Tensor] = []
        for lora_layer in concatenated_lora_layer.lora_layers:
            x_chunk = torch.nn.functional.linear(input, lora_layer.down)
            if lora_layer.mid is not None:
                x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
            x_chunk *= lora_weight * lora_layer.scale()
            x_chunks.append(x_chunk)

        # TODO(ryand): Generalize to support concat_axis != 0.
        assert concatenated_lora_layer.concat_axis == 0
        x = torch.cat(x_chunks, dim=-1)
        return x

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # First, apply the original linear layer.
        # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
        # change the linear layer's in_features.
        orig_input = input
        input = orig_input[..., : self.orig_module.in_features]
        output = self.orig_module(input)

        # Then, apply layers for which we have optimized implementations.
        unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
        for patch, patch_weight in self._patches_and_weights:
            if isinstance(patch, FluxControlLoRALayer):
                # Note that we use the original input here, not the sliced input.
                output += self._lora_forward(orig_input, patch, patch_weight)
            elif isinstance(patch, LoRALayer):
                output += self._lora_forward(input, patch, patch_weight)
            elif isinstance(patch, ConcatenatedLoRALayer):
                output += self._concatenated_lora_forward(input, patch, patch_weight)
            else:
                unprocessed_patches_and_weights.append((patch, patch_weight))

        # Finally, apply any remaining patches.
        if len(unprocessed_patches_and_weights) > 0:
            aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights)
            output += torch.nn.functional.linear(
                input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
            )

        return output
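The optimized _lora_forward path is equivalent to materializing the low-rank weight delta and adding it to the linear weight. A minimal standalone sketch of that equivalence (arbitrary shapes; scale stands in for lora_weight * lora_layer.scale()):

import torch

in_features, out_features, rank = 6, 4, 2
x = torch.randn(3, in_features)
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
scale = 0.5  # stands in for lora_weight * lora_layer.scale()

# Chained low-rank projections, as in _lora_forward.
residual_chained = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up) * scale
# Materialized weight delta, as used when patching the weights directly.
residual_patched = torch.nn.functional.linear(x, (up @ down) * scale)
assert torch.allclose(residual_chained, residual_patched, atol=1e-5)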

View File

@@ -1,20 +0,0 @@
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper


def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
    if isinstance(orig_module, torch.nn.Linear):
        return LinearSidecarWrapper(orig_module)
    elif isinstance(orig_module, torch.nn.Conv1d):
        return Conv1dSidecarWrapper(orig_module)
    elif isinstance(orig_module, torch.nn.Conv2d):
        return Conv2dSidecarWrapper(orig_module)
    elif isinstance(orig_module, RMSNorm):
        return FluxRMSNormSidecarWrapper(orig_module)
    else:
        raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
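A short usage sketch of the factory above (assuming the package layout at this commit's parent and the LoRALayer constructor shown in the tests below): wrap a module, then attach patches with add_patch().

import torch

from invokeai.backend.patches.layers.lora_layer import LoRALayer

linear = torch.nn.Linear(10, 20)
lora_layer = LoRALayer(up=torch.randn(20, 4), mid=None, down=torch.randn(4, 10), alpha=1.0, bias=torch.randn(20))

wrapped = wrap_module_with_sidecar_wrapper(linear)  # returns a LinearSidecarWrapper
wrapped.add_patch(lora_layer, 0.75)                 # apply the LoRA patch at 75% strength
output = wrapped(torch.randn(1, 10))                # original output plus the weighted LoRA residual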

View File

@@ -1,23 +0,0 @@
import torch

from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper


def test_flux_rms_norm_sidecar_wrapper():
    # Create a RMSNorm layer.
    dim = 10
    rms_norm = torch.nn.RMSNorm(dim)

    # Create a SetParameterLayer.
    new_scale = torch.randn(dim)
    set_parameter_layer = SetParameterLayer("scale", new_scale)

    # Create a FluxRMSNormSidecarWrapper.
    rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)])

    # Run the FluxRMSNormSidecarWrapper.
    input = torch.randn(1, dim)
    expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
    output_wrapped = rms_norm_wrapped(input)
    assert torch.allclose(output_wrapped, expected_output, atol=1e-6)

View File

@@ -1,182 +0,0 @@
import copy

import torch

from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper


@torch.no_grad()
def test_linear_sidecar_wrapper_lora():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a LoRA layer.
    rank = 4
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    bias = torch.randn(out_features)
    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

    # Patch the LoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


@torch.no_grad()
def test_linear_sidecar_wrapper_multiple_loras():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create two LoRA layers.
    rank = 4
    lora_layer = LoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )
    lora_layer_2 = LoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )
    # We use different weights for the two LoRA layers to ensure this is working.
    lora_weight = 1.0
    lora_weight_2 = 0.5

    # Patch the LoRA layers into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * (lora_layer.scale() * lora_weight)
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * (lora_layer.scale() * lora_weight)
    linear_patched.weight.data += lora_layer_2.get_weight(linear_patched.weight) * (
        lora_layer_2.scale() * lora_weight_2
    )
    linear_patched.bias.data += lora_layer_2.get_bias(linear_patched.bias) * (lora_layer_2.scale() * lora_weight_2)

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, lora_weight), (lora_layer_2, lora_weight_2)])

    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


@torch.no_grad()
def test_linear_sidecar_wrapper_concatenated_lora():
    # Create a linear layer.
    in_features = 5
    sub_layer_out_features = [5, 10, 15]
    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))

    # Create a ConcatenatedLoRA layer.
    rank = 4
    sub_layers: list[LoRALayer] = []
    for out_features in sub_layer_out_features:
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)

    # Patch the ConcatenatedLoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += (
        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
    )
    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)])

    # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_linear_sidecar_wrapper_full_layer():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a FullLayer.
    full_layer = FullLayer(weight=torch.randn(out_features, in_features), bias=torch.randn(out_features))

    # Patch the FullLayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += full_layer.get_weight(linear_patched.weight)
    linear_patched.bias.data += full_layer.get_bias(linear_patched.bias)

    # Create a LinearSidecarWrapper.
    full_wrapped = LinearSidecarWrapper(linear, [(full_layer, 1.0)])

    # Run the FullLayer-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = full_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_linear_sidecar_wrapper_flux_control_lora_layer():
    # Create a linear layer.
    orig_in_features = 10
    out_features = 40
    linear = torch.nn.Linear(orig_in_features, out_features)

    # Create a FluxControlLoRALayer.
    patched_in_features = 20
    rank = 4
    lora_layer = FluxControlLoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, patched_in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )

    # Patch the FluxControlLoRALayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    # Expand the existing weight.
    expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features]))
    linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad)
    # Expand the existing bias.
    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
    # Add the residuals.
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, patched_in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)