Add unit test for a SetParameterLayer patch applied to a CustomFluxRMSNorm layer.

Ryan Dick 2024-12-28 20:44:48 +00:00
parent 93e76b61d6
commit 918f541af8
2 changed files with 37 additions and 1 deletion


@@ -1,3 +1,5 @@
from typing import TypeVar
import torch
from invokeai.backend.flux.modules.layers import RMSNorm
@@ -52,7 +54,10 @@ except ImportError:
AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[torch.nn.Module]):
T = TypeVar("T", bound=torch.nn.Module)
def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T:
    # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
    # existing layer instance without calling __init__() on the original layer class. We achieve this by copying
    # the attributes from the original layer instance to the new instance.
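
The function body is not visible in this hunk, so the following is only a minimal sketch of the attribute-copying initialization the comment describes, assuming the copy is a plain __dict__ update; it is an illustration, not the commit's actual implementation:

def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T:
    # Allocate the custom layer without running its __init__(), then copy the
    # wrapped module's attributes onto the new instance. For nn.Module, this
    # __dict__ includes _parameters, _buffers, and _modules, so the wrapped
    # layer's state carries over while the custom class's methods take effect.
    custom_layer = custom_layer_type.__new__(custom_layer_type)
    custom_layer.__dict__.update(module_to_wrap.__dict__)
    return custom_layer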


@@ -0,0 +1,31 @@
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
    CustomFluxRMSNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    wrap_custom_layer,
)
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer


def test_custom_flux_rms_norm_patch():
    """Test a SetParameterLayer patch on a CustomFluxRMSNorm layer."""
    # Create a RMSNorm layer.
    dim = 8
    rms_norm = RMSNorm(dim)

    # Create a SetParameterLayer.
    new_scale = torch.randn(dim)
    set_parameter_layer = SetParameterLayer("scale", new_scale)

    # Wrap the RMSNorm layer in a CustomFluxRMSNorm layer.
    custom_flux_rms_norm = wrap_custom_layer(rms_norm, CustomFluxRMSNorm)
    custom_flux_rms_norm.add_patch(set_parameter_layer, 1.0)

    # Run the CustomFluxRMSNorm layer.
    input = torch.randn(1, dim)
    expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
    output_custom = custom_flux_rms_norm(input)
    assert torch.allclose(output_custom, expected_output, atol=1e-6)
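
As a follow-up sanity check (not part of this commit), the patched layer could also be compared against an unpatched RMSNorm whose weight is overwritten directly. This sketch assumes RMSNorm stores its weight in a parameter named scale (as the patch's target name suggests) and normalizes with eps=1e-6, matching the expected-output computation above:

    # Hypothetical cross-check: overwrite the parameter on a plain RMSNorm and
    # confirm it produces the same output as the patched custom layer.
    reference_norm = RMSNorm(dim)
    with torch.no_grad():
        reference_norm.scale.copy_(new_scale)  # "scale" assumed from the patch's target name
    assert torch.allclose(custom_flux_rms_norm(input), reference_norm(input), atol=1e-6)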