Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2025-01-07 03:17:05 +08:00)
Add more unit tests for custom module LoRA patching: multiple LoRAs and ConcatenatedLoRALayers.
commit 5ee7405f97 (parent e24e386a27)
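
For orientation, the property these tests exercise is that baking a LoRA delta directly into a Linear layer's weights should give the same result as applying that delta as a sidecar correction at inference time. Below is a minimal pure-PyTorch sketch of that idea; the variable names, the patch_weight value, and the omission of any alpha/rank scaling are illustrative assumptions, not InvokeAI's actual LoRALayer or custom-layer API.

import copy

import torch

torch.manual_seed(0)
in_features, out_features, rank = 10, 20, 4
base = torch.nn.Linear(in_features, out_features)

# A LoRA patch: a low-rank weight delta plus a bias delta. (Real LoRA layers
# typically also apply an alpha/rank scale; it is omitted here for brevity.)
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias_delta = torch.randn(out_features)
patch_weight = 0.7  # per-patch strength, as in the fixture below

x = torch.randn(1, in_features)

# Path 1: bake the weighted delta directly into a copy of the layer.
patched = copy.deepcopy(base)
with torch.no_grad():
    patched.weight += patch_weight * (up @ down)
    patched.bias += patch_weight * bias_delta

# Path 2: leave the base layer untouched and add the correction at forward
# time, which is roughly what a sidecar patch on a wrapped layer does.
sidecar_out = base(x) + patch_weight * (x @ (up @ down).T + bias_delta)

assert torch.allclose(patched(x), sidecar_out, atol=1e-5)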
@@ -11,6 +11,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     wrap_custom_layer,
 )
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
     build_linear_8bit_lt_layer,
@@ -263,18 +264,22 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
     assert torch.allclose(orig_output, custom_output)


-LayerAndPatchUnderTest = tuple[torch.nn.Module, BaseLayerPatch, torch.Tensor, bool]
+LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float]], torch.Tensor, bool]


 @pytest.fixture(
     params=[
-        "linear_lora",
+        "linear_single_lora",
+        "linear_multiple_loras",
+        "linear_concatenated_lora",
     ]
 )
 def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
     """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
     layer_type = request.param
-    if layer_type == "linear_lora":
+    torch.manual_seed(0)
+
+    if layer_type == "linear_single_lora":
         # Create a linear layer.
         in_features = 10
         out_features = 20
@@ -282,39 +287,85 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:

         # Create a LoRA layer.
         rank = 4
-        down = torch.randn(rank, in_features)
-        up = torch.randn(out_features, rank)
-        bias = torch.randn(out_features)
-        lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
+        lora_layer = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+        input = torch.randn(1, in_features)
+        return (layer, [(lora_layer, 0.7)], input, True)
+    elif layer_type == "linear_multiple_loras":
+        # Create a linear layer.
+        rank = 4
+        in_features = 10
+        out_features = 20
+        layer = torch.nn.Linear(in_features, out_features)
+
+        lora_layer = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+        lora_layer_2 = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )

         input = torch.randn(1, in_features)
-        return (layer, lora_layer, input, True)
+        return (layer, [(lora_layer, 1.0), (lora_layer_2, 0.5)], input, True)
+    elif layer_type == "linear_concatenated_lora":
+        # Create a linear layer.
+        in_features = 5
+        sub_layer_out_features = [5, 10, 15]
+        layer = torch.nn.Linear(in_features, sum(sub_layer_out_features))
+
+        # Create a ConcatenatedLoRA layer.
+        rank = 4
+        sub_layers: list[LoRALayer] = []
+        for out_features in sub_layer_out_features:
+            down = torch.randn(rank, in_features)
+            up = torch.randn(out_features, rank)
+            bias = torch.randn(out_features)
+            sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
+        concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
+
+        input = torch.randn(1, in_features)
+        return (layer, [(concatenated_lora_layer, 0.7)], input, True)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")


 @parameterize_all_devices
 def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest):
-    layer, patch, input, supports_cpu_inference = layer_and_patch_under_test
+    layer, patches, input, supports_cpu_inference = layer_and_patch_under_test

     if device == "cpu" and not supports_cpu_inference:
         pytest.skip("Layer does not support CPU inference.")

-    # Move the layer, patch, and input to the device.
+    # Move the layer and input to the device.
     layer_to_device_via_state_dict(layer, device)
-    patch.to(torch.device(device))
     input = input.to(torch.device(device))

     # Patch the LoRA layer into the linear layer.
-    weight = 0.7
     layer_patched = copy.deepcopy(layer)
-    parameters = patch.get_parameters(layer_patched, weight=weight)
-    for param_name, param_weight in parameters.items():
-        getattr(layer_patched, param_name).data += param_weight
+    for patch, weight in patches:
+        patch.to(torch.device(device))
+        parameters = patch.get_parameters(layer_patched, weight=weight)
+        for param_name, param_weight in parameters.items():
+            module_param = getattr(layer_patched, param_name)
+            module_param.data += param_weight

     # Wrap the original layer in a custom layer and add the patch to it as a sidecar.
     custom_layer = wrap_single_custom_layer(layer)
-    custom_layer.add_patch(patch, weight)
+    for patch, weight in patches:
+        custom_layer.add_patch(patch, weight)

     # Run inference with the original layer and the patched layer and assert they are equal.
     output_patched = layer_patched(input)
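
The fixture now yields a list of (patch, weight) pairs rather than a single patch, and test_sidecar_patches applies each pair in turn. A small plain-PyTorch sketch (hypothetical helper names, no InvokeAI APIs) of why the per-patch loop is equivalent to adding the weighted deltas in one step:

import copy

import torch

torch.manual_seed(0)
in_features, out_features, rank = 10, 20, 4
layer = torch.nn.Linear(in_features, out_features)


def random_delta() -> tuple[torch.Tensor, torch.Tensor]:
    # A (weight_delta, bias_delta) pair for one random low-rank patch.
    up = torch.randn(out_features, rank)
    down = torch.randn(rank, in_features)
    return up @ down, torch.randn(out_features)


# Two patches with different strengths, as in the "linear_multiple_loras" case.
patches = [(random_delta(), 1.0), (random_delta(), 0.5)]

# Apply the patches one at a time, mirroring the loop in test_sidecar_patches.
iterative = copy.deepcopy(layer)
with torch.no_grad():
    for (w_delta, b_delta), weight in patches:
        iterative.weight += weight * w_delta
        iterative.bias += weight * b_delta

# Apply the summed weighted deltas in a single step.
summed = copy.deepcopy(layer)
with torch.no_grad():
    summed.weight += sum(weight * w_delta for (w_delta, _), weight in patches)
    summed.bias += sum(weight * b_delta for (_, b_delta), weight in patches)

x = torch.randn(1, in_features)
assert torch.allclose(iterative(x), summed(x), atol=1e-5)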
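
The "linear_concatenated_lora" case pairs a Linear layer whose out_features is the sum of the sub-layers' out_features with a ConcatenatedLoRALayer built with concat_axis=0. The sketch below illustrates the assumed semantics (each sub-LoRA contributes a delta for one slice of the output dimension, concatenated along dim 0); it uses plain PyTorch rather than the repository's ConcatenatedLoRALayer.

import torch

torch.manual_seed(0)
in_features = 5
sub_layer_out_features = [5, 10, 15]
rank = 4
layer = torch.nn.Linear(in_features, sum(sub_layer_out_features))

weight_deltas = []
bias_deltas = []
for out_features in sub_layer_out_features:
    up = torch.randn(out_features, rank)
    down = torch.randn(rank, in_features)
    weight_deltas.append(up @ down)  # shape: (out_features, in_features)
    bias_deltas.append(torch.randn(out_features))

# Concatenating along the output axis (concat_axis=0) yields a delta that
# matches the full layer's weight and bias shapes.
full_weight_delta = torch.cat(weight_deltas, dim=0)  # (30, 5)
full_bias_delta = torch.cat(bias_deltas, dim=0)      # (30,)

with torch.no_grad():
    layer.weight += 0.7 * full_weight_delta
    layer.bias += 0.7 * full_bias_delta

x = torch.randn(1, in_features)
print(layer(x).shape)  # torch.Size([1, 30])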