Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2025-01-08 11:57:36 +08:00.
Add more unit tests for custom module LoRA patching: multiple LoRAs and ConcatenatedLoRALayers.
parent e24e386a27
commit 5ee7405f97
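Background for the diff below: test_sidecar_patches compares two ways of applying LoRA patches to a linear layer, baking each patch's parameter deltas into a deep copy of the layer versus attaching the patches as sidecars to a wrapped custom layer. The sketch that follows is plain PyTorch, not InvokeAI's API: bake_lora_patches is a hypothetical helper and the alpha/rank scaling is an assumed convention. It illustrates the baked-weights arithmetic for a stack of weighted LoRA patches, which the reference path in the updated test performs via patch.get_parameters().

import copy

import torch


def bake_lora_patches(
    layer: torch.nn.Linear,
    patches: list[tuple[torch.Tensor, torch.Tensor, float, float]],
) -> torch.nn.Linear:
    """Bake (up, down, alpha, patch_weight) LoRA patches into a copy of `layer`.

    Illustrative only; the alpha/rank scaling below is an assumption, not taken from this diff.
    """
    patched = copy.deepcopy(layer)
    for up, down, alpha, patch_weight in patches:
        rank = down.shape[0]
        scale = patch_weight * alpha / rank  # assumed LoRA scaling convention
        patched.weight.data += scale * (up @ down)  # delta has shape (out_features, in_features)
    return patched


# Example mirroring the "linear_multiple_loras" case: two patches with weights 1.0 and 0.5.
layer = torch.nn.Linear(10, 20)
rank = 4
patches = [
    (torch.randn(20, rank), torch.randn(rank, 10), 1.0, 1.0),
    (torch.randn(20, rank), torch.randn(rank, 10), 1.0, 0.5),
]
patched = bake_lora_patches(layer, patches)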
@@ -11,6 +11,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     wrap_custom_layer,
 )
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
     build_linear_8bit_lt_layer,
@@ -263,18 +264,22 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
     assert torch.allclose(orig_output, custom_output)
 
 
-LayerAndPatchUnderTest = tuple[torch.nn.Module, BaseLayerPatch, torch.Tensor, bool]
+LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float]], torch.Tensor, bool]
 
 
 @pytest.fixture(
     params=[
-        "linear_lora",
+        "linear_single_lora",
+        "linear_multiple_loras",
+        "linear_concatenated_lora",
     ]
 )
 def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
     """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
     layer_type = request.param
-    if layer_type == "linear_lora":
+    torch.manual_seed(0)
+
+    if layer_type == "linear_single_lora":
         # Create a linear layer.
         in_features = 10
         out_features = 20
@@ -282,39 +287,85 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
 
         # Create a LoRA layer.
         rank = 4
-        down = torch.randn(rank, in_features)
-        up = torch.randn(out_features, rank)
-        bias = torch.randn(out_features)
-        lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
+        lora_layer = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+        input = torch.randn(1, in_features)
+        return (layer, [(lora_layer, 0.7)], input, True)
+    elif layer_type == "linear_multiple_loras":
+        # Create a linear layer.
+        rank = 4
+        in_features = 10
+        out_features = 20
+        layer = torch.nn.Linear(in_features, out_features)
+
+        lora_layer = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+        lora_layer_2 = LoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
 
         input = torch.randn(1, in_features)
-        return (layer, lora_layer, input, True)
+        return (layer, [(lora_layer, 1.0), (lora_layer_2, 0.5)], input, True)
+    elif layer_type == "linear_concatenated_lora":
+        # Create a linear layer.
+        in_features = 5
+        sub_layer_out_features = [5, 10, 15]
+        layer = torch.nn.Linear(in_features, sum(sub_layer_out_features))
+
+        # Create a ConcatenatedLoRA layer.
+        rank = 4
+        sub_layers: list[LoRALayer] = []
+        for out_features in sub_layer_out_features:
+            down = torch.randn(rank, in_features)
+            up = torch.randn(out_features, rank)
+            bias = torch.randn(out_features)
+            sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
+        concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
+
+        input = torch.randn(1, in_features)
+        return (layer, [(concatenated_lora_layer, 0.7)], input, True)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
 
 
 @parameterize_all_devices
 def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest):
-    layer, patch, input, supports_cpu_inference = layer_and_patch_under_test
+    layer, patches, input, supports_cpu_inference = layer_and_patch_under_test
 
     if device == "cpu" and not supports_cpu_inference:
         pytest.skip("Layer does not support CPU inference.")
 
-    # Move the layer, patch, and input to the device.
+    # Move the layer and input to the device.
     layer_to_device_via_state_dict(layer, device)
-    patch.to(torch.device(device))
     input = input.to(torch.device(device))
 
     # Patch the LoRA layer into the linear layer.
-    weight = 0.7
     layer_patched = copy.deepcopy(layer)
-    parameters = patch.get_parameters(layer_patched, weight=weight)
-    for param_name, param_weight in parameters.items():
-        getattr(layer_patched, param_name).data += param_weight
+    for patch, weight in patches:
+        patch.to(torch.device(device))
+        parameters = patch.get_parameters(layer_patched, weight=weight)
+        for param_name, param_weight in parameters.items():
+            module_param = getattr(layer_patched, param_name)
+            module_param.data += param_weight
 
     # Wrap the original layer in a custom layer and add the patch to it as a sidecar.
     custom_layer = wrap_single_custom_layer(layer)
-    custom_layer.add_patch(patch, weight)
+    for patch, weight in patches:
+        custom_layer.add_patch(patch, weight)
 
     # Run inference with the original layer and the patched layer and assert they are equal.
     output_patched = layer_patched(input)
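The new "linear_concatenated_lora" case pairs a fused linear layer (out_features = 5 + 10 + 15 = 30) with a single ConcatenatedLoRALayer built from three sub-LoRAs. Below is a minimal shape-bookkeeping sketch, assuming concat_axis=0 means the sub-layers' weight deltas are stacked along the output dimension; the ConcatenatedLoRALayer implementation itself is not part of this diff.

import torch

# Shapes match the fixture above: Linear(5, 30) patched by three sub-LoRAs of rank 4.
in_features = 5
sub_layer_out_features = [5, 10, 15]
rank = 4

deltas = []
for out_features in sub_layer_out_features:
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    deltas.append(up @ down)  # per-sub-layer delta of shape (out_features, in_features)

# Stacking along dim 0 yields a (30, 5) delta that lines up with the fused layer's weight.
concatenated_delta = torch.cat(deltas, dim=0)
assert concatenated_delta.shape == (sum(sub_layer_out_features), in_features)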
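For intuition on the other half of the comparison, here is a conceptual sidecar wrapper. This is an illustrative assumption, not InvokeAI's wrap_single_custom_layer / add_patch implementation: the base layer's weights stay untouched and each patch adds its LoRA contribution at inference time, which is why the test expects the sidecar output to match the baked-weights output.

import torch


class SidecarLinear(torch.nn.Module):
    """Hypothetical sidecar wrapper around a frozen base linear layer (sketch only)."""

    def __init__(self, base: torch.nn.Linear):
        super().__init__()
        self.base = base
        # Each entry is (up, down, patch_weight); scaling by alpha/rank is omitted for brevity.
        self.patches: list[tuple[torch.Tensor, torch.Tensor, float]] = []

    def add_patch(self, up: torch.Tensor, down: torch.Tensor, patch_weight: float) -> None:
        self.patches.append((up, down, patch_weight))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        for up, down, patch_weight in self.patches:
            out = out + patch_weight * (x @ down.T @ up.T)  # LoRA term; base weights untouched
        return out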