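"""Tests for LinearSidecarWrapper.

Each test bakes a patch layer (LoRALayer, ConcatenatedLoRALayer, FullLayer, or
FluxControlLoRALayer) directly into a copy of a torch.nn.Linear layer, then checks
that wrapping the unmodified layer in a LinearSidecarWrapper produces the same output.
"""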
import copy

import torch

from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
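# Shared expectation for the tests below: LinearSidecarWrapper(orig, [(layer, w), ...])
# should match baking each patch layer's weight/bias deltas, scaled by
# layer.scale() * w, directly into the layer's parameters. For a linear layer this
# equivalence follows from linearity:
#   (W + dW) @ x + (b + db) == (W @ x + b) + (dW @ x + db)
# so a patch baked into the weights and a patch applied to the output agree.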


@torch.no_grad()
def test_linear_sidecar_wrapper_lora():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a LoRA layer.
    rank = 4
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    bias = torch.randn(out_features)
    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

    # Patch the LoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


@torch.no_grad()
def test_linear_sidecar_wrapper_multiple_loras():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create two LoRA layers.
    rank = 4
    lora_layer = LoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )
    lora_layer_2 = LoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )
    # Use different patch weights for the two LoRA layers so that errors in the
    # per-layer weighting would be caught.
    lora_weight = 1.0
    lora_weight_2 = 0.5

    # Patch the LoRA layers into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * (lora_layer.scale() * lora_weight)
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * (lora_layer.scale() * lora_weight)
    linear_patched.weight.data += lora_layer_2.get_weight(linear_patched.weight) * (
        lora_layer_2.scale() * lora_weight_2
    )
    linear_patched.bias.data += lora_layer_2.get_bias(linear_patched.bias) * (lora_layer_2.scale() * lora_weight_2)

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, lora_weight), (lora_layer_2, lora_weight_2)])

    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


@torch.no_grad()
def test_linear_sidecar_wrapper_concatenated_lora():
    # Create a linear layer.
    in_features = 5
    sub_layer_out_features = [5, 10, 15]
    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))

    # Create a ConcatenatedLoRA layer. Its sub-layers are concatenated along the
    # output-feature axis (concat_axis=0), so their combined output dimension matches
    # the linear layer's out_features.
    rank = 4
    sub_layers: list[LoRALayer] = []
    for out_features in sub_layer_out_features:
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)

    # Patch the ConcatenatedLoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += (
        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
    )
    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)])

    # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


@torch.no_grad()
def test_linear_sidecar_wrapper_full_layer():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a FullLayer.
    full_layer = FullLayer(weight=torch.randn(out_features, in_features), bias=torch.randn(out_features))

    # Patch the FullLayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += full_layer.get_weight(linear_patched.weight)
    linear_patched.bias.data += full_layer.get_bias(linear_patched.bias)

    # Create a LinearSidecarWrapper.
    full_wrapped = LinearSidecarWrapper(linear, [(full_layer, 1.0)])

    # Run the FullLayer-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = full_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


@torch.no_grad()
def test_linear_sidecar_wrapper_flux_control_lora_layer():
    # Create a linear layer.
    orig_in_features = 10
    out_features = 40
    linear = torch.nn.Linear(orig_in_features, out_features)

    # Create a FluxControlLoRALayer. Its `down` matrix expects more input features
    # than the original layer has (20 vs. 10), so baking it in requires expanding the
    # original weight matrix first.
    patched_in_features = 20
    rank = 4
    lora_layer = FluxControlLoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, patched_in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )

    # Patch the FluxControlLoRALayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    # Expand the existing weight with zero-padding to accept the new input features.
    expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features]))
    linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad)
    # Expand the existing bias (a no-op here, since out_features is unchanged).
    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
    # Add the residuals.
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, patched_in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)