Add support for patches to CustomModuleMixin and add a single unit test (more to come).

Ryan Dick 2024-12-27 18:57:13 +00:00
parent b06d61e3c0
commit e24e386a27
4 changed files with 136 additions and 1 deletion

View File

@@ -4,16 +4,27 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
    add_nullable_tensors,
)


class CustomLinear(torch.nn.Linear, CustomModuleMixin):
    def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
        weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
        bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
        return torch.nn.functional.linear(input, weight, bias)

    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(input)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(input)
        else:
            return super().forward(input)
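
For context, the sidecar path in `_autocast_forward_with_patches` should be numerically equivalent to adding the aggregated residuals directly to the layer's parameters, which is what the new unit test below verifies. A minimal standalone sketch of that equivalence using only torch; `delta_w` and `delta_b` are illustrative stand-ins for the aggregated patch residuals:

```python
import copy

import torch
import torch.nn.functional as F

torch.manual_seed(0)
base = torch.nn.Linear(10, 20)
x = torch.randn(1, 10)
delta_w = 0.1 * torch.randn(20, 10)  # stand-in for aggregated_param_residuals["weight"]
delta_b = 0.1 * torch.randn(20)      # stand-in for aggregated_param_residuals["bias"]

# Sidecar style: leave the base parameters untouched and add the residuals at call time,
# as _autocast_forward_with_patches does.
y_sidecar = F.linear(x, base.weight + delta_w, base.bias + delta_b)

# Baked-in style: copy the layer and add the residuals directly to its parameters.
baked = copy.deepcopy(base)
baked.weight.data += delta_w
baked.bias.data += delta_b
y_baked = baked(x)

assert torch.allclose(y_sidecar, y_baked)
```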

View File

@@ -1,8 +1,14 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch


class CustomModuleMixin:
    """A mixin class for custom modules that enables device autocasting of module parameters."""

    def __init__(self):
        self._device_autocasting_enabled = False
        self._patches_and_weights: list[tuple[BaseLayerPatch, float]] = []

    def set_device_autocasting_enabled(self, enabled: bool):
        """Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to
@@ -10,3 +16,30 @@ class CustomModuleMixin:
        not needed.
        """
        self._device_autocasting_enabled = enabled

    def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
        """Add a patch to the sidecar wrapper."""
        self._patches_and_weights.append((patch, patch_weight))

    def clear_patches(self):
        """Clear all patches from the sidecar wrapper."""
        self._patches_and_weights = []

    def _aggregate_patch_parameters(
        self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
    ) -> dict[str, torch.Tensor]:
        """Helper function that aggregates the parameters from all patches into a single dict."""
        params: dict[str, torch.Tensor] = {}
        for patch, patch_weight in patches_and_weights:
            # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
            # module, this might fail or return incorrect results.
            layer_params = patch.get_parameters(self, weight=patch_weight)
            for param_name, param_weight in layer_params.items():
                if param_name not in params:
                    params[param_name] = param_weight
                else:
                    params[param_name] += param_weight
        return params
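
To make the new mixin API concrete, here is a small hypothetical usage sketch. `_DummyPatch` and `_PatchableLinear` are invented for this example (real patches subclass `BaseLayerPatch`, and real modules are converted via the custom-layer wrappers rather than subclassed directly); the dummy only mimics the `get_parameters(module, weight=...)` call that `_aggregate_patch_parameters` makes:

```python
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)


class _DummyPatch:
    """A stand-in patch that returns a fixed residual for the "weight" parameter, scaled by the patch weight."""

    def __init__(self, residual: torch.Tensor):
        self._residual = residual

    def get_parameters(self, module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
        return {"weight": self._residual * weight}


class _PatchableLinear(torch.nn.Linear, CustomModuleMixin):
    def __init__(self, in_features: int, out_features: int):
        torch.nn.Linear.__init__(self, in_features, out_features)
        CustomModuleMixin.__init__(self)  # Initialize _patches_and_weights / _device_autocasting_enabled.


layer = _PatchableLinear(4, 4)
layer.add_patch(_DummyPatch(torch.ones(4, 4)), patch_weight=0.5)
layer.add_patch(_DummyPatch(torch.ones(4, 4)), patch_weight=0.25)

# Residuals for the same parameter name are summed across patches: 0.5 + 0.25 = 0.75.
residuals = layer._aggregate_patch_parameters(layer._patches_and_weights)
assert torch.allclose(residuals["weight"], torch.full((4, 4), 0.75))

layer.clear_patches()
assert len(layer._patches_and_weights) == 0
```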

View File

@@ -0,0 +1,30 @@
from typing import overload

import torch


@overload
def add_nullable_tensors(a: None, b: None) -> None: ...


@overload
def add_nullable_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ...


@overload
def add_nullable_tensors(a: torch.Tensor, b: None) -> torch.Tensor: ...


@overload
def add_nullable_tensors(a: None, b: torch.Tensor) -> torch.Tensor: ...


def add_nullable_tensors(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor | None:
    if a is None and b is None:
        return None
    elif a is None:
        return b
    elif b is None:
        return a
    else:
        return a + b
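
A quick illustration of the helper's None-handling (the import path mirrors the one used in the CustomLinear diff above):

```python
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
    add_nullable_tensors,
)

a = torch.ones(2)
b = torch.full((2,), 2.0)

assert torch.allclose(add_nullable_tensors(a, b), torch.full((2,), 3.0))  # both present: element-wise sum
assert add_nullable_tensors(a, None) is a  # one side missing: return the other unchanged
assert add_nullable_tensors(None, b) is b
assert add_nullable_tensors(None, None) is None  # both missing: nothing to return
```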

View File

@@ -10,6 +10,8 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
    unwrap_custom_layer,
    wrap_custom_layer,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
    build_linear_8bit_lt_layer,
)
@@ -259,3 +261,62 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
    assert custom_output.device.type == device
    assert torch.allclose(orig_output, custom_output)
LayerAndPatchUnderTest = tuple[torch.nn.Module, BaseLayerPatch, torch.Tensor, bool]


@pytest.fixture(
    params=[
        "linear_lora",
    ]
)
def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
    """A fixture that returns a tuple of (layer, patch, input, supports_cpu_inference) for the layer under test."""
    layer_type = request.param
    if layer_type == "linear_lora":
        # Create a linear layer.
        in_features = 10
        out_features = 20
        layer = torch.nn.Linear(in_features, out_features)

        # Create a LoRA layer.
        rank = 4
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

        input = torch.randn(1, in_features)
        return (layer, lora_layer, input, True)
    else:
        raise ValueError(f"Unsupported layer_type: {layer_type}")


@parameterize_all_devices
def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest):
    layer, patch, input, supports_cpu_inference = layer_and_patch_under_test
    if device == "cpu" and not supports_cpu_inference:
        pytest.skip("Layer does not support CPU inference.")

    # Move the layer, patch, and input to the device.
    layer_to_device_via_state_dict(layer, device)
    patch.to(torch.device(device))
    input = input.to(torch.device(device))

    # Patch the LoRA layer into a copy of the linear layer by adding the residuals to its parameters.
    weight = 0.7
    layer_patched = copy.deepcopy(layer)
    parameters = patch.get_parameters(layer_patched, weight=weight)
    for param_name, param_weight in parameters.items():
        getattr(layer_patched, param_name).data += param_weight

    # Wrap the original layer in a custom layer and add the patch to it as a sidecar.
    custom_layer = wrap_single_custom_layer(layer)
    custom_layer.add_patch(patch, weight)

    # Run inference with the in-place-patched layer and the sidecar-patched custom layer and assert that the
    # outputs match.
    output_patched = layer_patched(input)
    output_custom = custom_layer(input)
    assert torch.allclose(output_patched, output_custom)