Add more unit tests for custom module LoRA patching: multiple LoRAs and ConcatenatedLoRALayers.

Ryan Dick 2024-12-27 19:47:21 +00:00
parent e24e386a27
commit 5ee7405f97


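For context, here is a minimal plain-PyTorch sketch of the two behaviors the new test cases cover: stacking several LoRA patches on one linear layer (each with its own patch weight), and concatenating per-sub-layer LoRA deltas along the output axis, as ConcatenatedLoRALayer does with concat_axis=0. The lora_delta helper and the alpha/rank scaling are assumptions based on the standard LoRA convention, not the InvokeAI implementation.

import torch

def lora_delta(up: torch.Tensor, down: torch.Tensor, alpha: float) -> torch.Tensor:
    # Standard LoRA weight delta: (alpha / rank) * (up @ down). Assumed convention,
    # not necessarily the exact scaling used by InvokeAI's LoRALayer.
    rank = down.shape[0]
    return (alpha / rank) * (up @ down)

torch.manual_seed(0)

# Case 1: multiple LoRA patches stacked on one linear layer, each with its own patch weight.
layer = torch.nn.Linear(10, 20)
patched = torch.nn.Linear(10, 20)
patched.load_state_dict(layer.state_dict())
patches = [
    (torch.randn(20, 4), torch.randn(4, 10), 1.0, 1.0),  # (up, down, alpha, patch_weight)
    (torch.randn(20, 4), torch.randn(4, 10), 1.0, 0.5),
]
for up, down, alpha, patch_weight in patches:
    patched.weight.data += patch_weight * lora_delta(up, down, alpha)

# Case 2: a concatenated LoRA stacks per-sub-layer deltas along the output axis
# (concat_axis=0), matching a linear layer whose out_features is the sum of the
# sub-layer out_features (5 + 10 + 15 = 30 here, as in the new fixture).
sub_out = [5, 10, 15]
concat_delta = torch.cat(
    [lora_delta(torch.randn(o, 4), torch.randn(4, 5), 1.0) for o in sub_out], dim=0
)
assert concat_delta.shape == (sum(sub_out), 5)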
@@ -11,6 +11,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     wrap_custom_layer,
 )
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
     build_linear_8bit_lt_layer,
@@ -263,18 +264,22 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
     assert torch.allclose(orig_output, custom_output)
-LayerAndPatchUnderTest = tuple[torch.nn.Module, BaseLayerPatch, torch.Tensor, bool]
+LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float]], torch.Tensor, bool]
 @pytest.fixture(
     params=[
-        "linear_lora",
+        "linear_single_lora",
+        "linear_multiple_loras",
+        "linear_concatenated_lora",
     ]
 )
 def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
     """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
     layer_type = request.param
-    if layer_type == "linear_lora":
+    torch.manual_seed(0)
+    if layer_type == "linear_single_lora":
         # Create a linear layer.
         in_features = 10
         out_features = 20
@@ -282,39 +287,85 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
         # Create a LoRA layer.
         rank = 4
-        down = torch.randn(rank, in_features)
-        up = torch.randn(out_features, rank)
-        bias = torch.randn(out_features)
-        lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
+        lora_layer = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+        input = torch.randn(1, in_features)
+        return (layer, [(lora_layer, 0.7)], input, True)
+    elif layer_type == "linear_multiple_loras":
+        # Create a linear layer.
+        rank = 4
+        in_features = 10
+        out_features = 20
+        layer = torch.nn.Linear(in_features, out_features)
+        lora_layer = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+        lora_layer_2 = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
         input = torch.randn(1, in_features)
-        return (layer, lora_layer, input, True)
+        return (layer, [(lora_layer, 1.0), (lora_layer_2, 0.5)], input, True)
+    elif layer_type == "linear_concatenated_lora":
+        # Create a linear layer.
+        in_features = 5
+        sub_layer_out_features = [5, 10, 15]
+        layer = torch.nn.Linear(in_features, sum(sub_layer_out_features))
+        # Create a ConcatenatedLoRA layer.
+        rank = 4
+        sub_layers: list[LoRALayer] = []
+        for out_features in sub_layer_out_features:
+            down = torch.randn(rank, in_features)
+            up = torch.randn(out_features, rank)
+            bias = torch.randn(out_features)
+            sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
+        concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
+        input = torch.randn(1, in_features)
+        return (layer, [(concatenated_lora_layer, 0.7)], input, True)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
 @parameterize_all_devices
 def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest):
-    layer, patch, input, supports_cpu_inference = layer_and_patch_under_test
+    layer, patches, input, supports_cpu_inference = layer_and_patch_under_test
     if device == "cpu" and not supports_cpu_inference:
         pytest.skip("Layer does not support CPU inference.")
-    # Move the layer, patch, and input to the device.
+    # Move the layer and input to the device.
     layer_to_device_via_state_dict(layer, device)
-    patch.to(torch.device(device))
     input = input.to(torch.device(device))
     # Patch the LoRA layer into the linear layer.
-    weight = 0.7
     layer_patched = copy.deepcopy(layer)
-    parameters = patch.get_parameters(layer_patched, weight=weight)
-    for param_name, param_weight in parameters.items():
-        getattr(layer_patched, param_name).data += param_weight
+    for patch, weight in patches:
+        patch.to(torch.device(device))
+        parameters = patch.get_parameters(layer_patched, weight=weight)
+        for param_name, param_weight in parameters.items():
+            module_param = getattr(layer_patched, param_name)
+            module_param.data += param_weight
     # Wrap the original layer in a custom layer and add the patch to it as a sidecar.
     custom_layer = wrap_single_custom_layer(layer)
-    custom_layer.add_patch(patch, weight)
+    for patch, weight in patches:
+        custom_layer.add_patch(patch, weight)
     # Run inference with the original layer and the patched layer and assert they are equal.
     output_patched = layer_patched(input)