Add support for patches to CustomModuleMixin and add a single unit test (more to come).
parent b06d61e3c0
commit e24e386a27
@@ -4,16 +4,27 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
    add_nullable_tensors,
)


class CustomLinear(torch.nn.Linear, CustomModuleMixin):
    def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
        weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
        bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
        return torch.nn.functional.linear(input, weight, bias)

    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(input)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(input)
        else:
            return super().forward(input)
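Conceptually, _autocast_forward_with_patches runs a single linear call with residual-patched parameters: W_patched = W + sum_i(w_i * dW_i). A standalone sketch of that arithmetic in plain torch (illustrative shapes and names, not code from this commit):

import torch

# Base linear parameters (illustrative shapes).
weight = torch.randn(20, 10)
bias = torch.randn(20)

# Two hypothetical patch residuals with their blend weights, mirroring the
# (patch, patch_weight) pairs stored in _patches_and_weights.
residuals = [(torch.randn(20, 10), 0.7), (torch.randn(20, 10), 0.3)]

# Sum the weighted residuals, as _aggregate_patch_parameters does per parameter name.
aggregated = sum(w * delta for delta, w in residuals)

# Apply the patched parameters in one functional call, as _autocast_forward_with_patches does.
x = torch.randn(1, 10)
y = torch.nn.functional.linear(x, weight + aggregated, bias)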
@@ -1,8 +1,14 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch


class CustomModuleMixin:
    """A mixin class for custom modules that enables device autocasting of module parameters."""

    def __init__(self):
        self._device_autocasting_enabled = False
        self._patches_and_weights: list[tuple[BaseLayerPatch, float]] = []

    def set_device_autocasting_enabled(self, enabled: bool):
        """Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to

@@ -10,3 +16,30 @@ class CustomModuleMixin:
        not needed.
        """
        self._device_autocasting_enabled = enabled

    def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
        """Add a patch to the sidecar wrapper."""
        self._patches_and_weights.append((patch, patch_weight))

    def clear_patches(self):
        """Clear all patches from the sidecar wrapper."""
        self._patches_and_weights = []

    def _aggregate_patch_parameters(
        self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
    ) -> dict[str, torch.Tensor]:
        """Helper function that aggregates the parameters from all patches into a single dict."""
        params: dict[str, torch.Tensor] = {}

        for patch, patch_weight in patches_and_weights:
            # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
            # module, this might fail or return incorrect results.
            layer_params = patch.get_parameters(self, weight=patch_weight)

            for param_name, param_weight in layer_params.items():
                if param_name not in params:
                    params[param_name] = param_weight
                else:
                    params[param_name] += param_weight

        return params
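To illustrate how the new patch bookkeeping composes, here is a minimal sketch (not part of the commit) that drives add_patch, _aggregate_patch_parameters, and clear_patches with a toy patch; ToyPatch is a hypothetical stand-in that implements only the get_parameters call the mixin actually makes:

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)


class ToyPatch:
    """Hypothetical patch: returns a fixed weight residual scaled by the blend weight."""

    def __init__(self, delta: torch.Tensor):
        self._delta = delta

    def get_parameters(self, module, weight: float) -> dict[str, torch.Tensor]:
        return {"weight": self._delta * weight}


mixin = CustomModuleMixin()
mixin.add_patch(ToyPatch(torch.ones(2, 2)), 0.25)
mixin.add_patch(ToyPatch(torch.ones(2, 2)), 0.75)

# Residuals from both patches are summed per parameter name: 0.25 + 0.75 = 1.0 everywhere.
params = mixin._aggregate_patch_parameters(mixin._patches_and_weights)
assert torch.allclose(params["weight"], torch.ones(2, 2))

mixin.clear_patches()
assert mixin._patches_and_weights == []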
@@ -0,0 +1,30 @@
from typing import overload

import torch


@overload
def add_nullable_tensors(a: None, b: None) -> None: ...


@overload
def add_nullable_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ...


@overload
def add_nullable_tensors(a: torch.Tensor, b: None) -> torch.Tensor: ...


@overload
def add_nullable_tensors(a: None, b: torch.Tensor) -> torch.Tensor: ...


def add_nullable_tensors(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor | None:
    if a is None and b is None:
        return None
    elif a is None:
        return b
    elif b is None:
        return a
    else:
        return a + b
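For reference, the four overloads behave as follows (a small usage sketch, assuming the new utils module is on the import path):

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
    add_nullable_tensors,
)

a = torch.tensor([1.0, 2.0])
b = torch.tensor([0.5, 0.5])

assert torch.allclose(add_nullable_tensors(a, b), torch.tensor([1.5, 2.5]))
assert add_nullable_tensors(a, None) is a  # one side missing: return the other unchanged
assert add_nullable_tensors(None, b) is b
assert add_nullable_tensors(None, None) is None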
@@ -10,6 +10,8 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
    unwrap_custom_layer,
    wrap_custom_layer,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
    build_linear_8bit_lt_layer,
)

@@ -259,3 +261,62 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
    assert custom_output.device.type == device

    assert torch.allclose(orig_output, custom_output)


LayerAndPatchUnderTest = tuple[torch.nn.Module, BaseLayerPatch, torch.Tensor, bool]


@pytest.fixture(
    params=[
        "linear_lora",
    ]
)
def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
    """A fixture that returns a tuple of (layer, patch, input, supports_cpu_inference) for the layer under test."""
    layer_type = request.param
    if layer_type == "linear_lora":
        # Create a linear layer.
        in_features = 10
        out_features = 20
        layer = torch.nn.Linear(in_features, out_features)

        # Create a LoRA layer.
        rank = 4
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

        input = torch.randn(1, in_features)
        return (layer, lora_layer, input, True)
    else:
        raise ValueError(f"Unsupported layer_type: {layer_type}")


@parameterize_all_devices
def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest):
    layer, patch, input, supports_cpu_inference = layer_and_patch_under_test

    if device == "cpu" and not supports_cpu_inference:
        pytest.skip("Layer does not support CPU inference.")

    # Move the layer, patch, and input to the device.
    layer_to_device_via_state_dict(layer, device)
    patch.to(torch.device(device))
    input = input.to(torch.device(device))

    # Patch the LoRA layer into the linear layer.
    weight = 0.7
    layer_patched = copy.deepcopy(layer)
    parameters = patch.get_parameters(layer_patched, weight=weight)
    for param_name, param_weight in parameters.items():
        getattr(layer_patched, param_name).data += param_weight

    # Wrap the original layer in a custom layer and add the patch to it as a sidecar.
    custom_layer = wrap_single_custom_layer(layer)
    custom_layer.add_patch(patch, weight)

    # Run inference with the original layer and the patched layer and assert they are equal.
    output_patched = layer_patched(input)
    output_custom = custom_layer(input)
    assert torch.allclose(output_patched, output_custom)
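The equivalence this test asserts can also be seen outside the test harness: baking the residual directly into the weights (the layer_patched path) and adding it at forward time (what the sidecar does) give the same output. A standalone sketch in plain torch; the up @ down residual and its scaling are illustrative assumptions, not the exact LoRALayer formula:

import copy

import torch

in_features, out_features, rank = 10, 20, 4
layer = torch.nn.Linear(in_features, out_features)

# Illustrative LoRA-style residual (the real LoRALayer also applies an alpha-based scale and an optional bias term).
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
patch_weight = 0.7
residual = patch_weight * (up @ down)

# Path 1: bake the residual into a copy of the layer's weights.
layer_patched = copy.deepcopy(layer)
layer_patched.weight.data += residual

# Path 2: keep the original weights and add the residual at forward time (the sidecar approach).
x = torch.randn(1, in_features)
sidecar_output = torch.nn.functional.linear(x, layer.weight + residual, layer.bias)

assert torch.allclose(layer_patched(x), sidecar_output)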