From 918f541af8660d15be1bc4fae3e5b42743fef60e Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sat, 28 Dec 2024 20:44:48 +0000
Subject: [PATCH] Add unit test for a SetParameterLayer patch applied to a
 CustomFluxRMSNorm layer.

---
 .../torch_module_autocast.py                  |  7 ++++-
 .../test_custom_flux_rms_norm.py              | 31 +++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)
 create mode 100644 tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py

diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
index 2d85e32370..73d5ec1ee5 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
@@ -1,3 +1,5 @@
+from typing import TypeVar
+
 import torch
 
 from invokeai.backend.flux.modules.layers import RMSNorm
@@ -52,7 +54,10 @@ except ImportError:
 AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
 
 
-def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[torch.nn.Module]):
+T = TypeVar("T", bound=torch.nn.Module)
+
+
+def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T:
     # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
     # existing layer instance without calling __init__() on the original layer class. We achieve this by copying
     # the attributes from the original layer instance to the new instance.
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py
new file mode 100644
index 0000000000..05e15302d5
--- /dev/null
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py
@@ -0,0 +1,31 @@
+import torch
+
+from invokeai.backend.flux.modules.layers import RMSNorm
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
+    CustomFluxRMSNorm,
+)
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
+    wrap_custom_layer,
+)
+from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
+
+
+def test_custom_flux_rms_norm_patch():
+    """Test a SetParameterLayer patch on a CustomFluxRMSNorm layer."""
+    # Create a RMSNorm layer.
+    dim = 8
+    rms_norm = RMSNorm(dim)
+
+    # Create a SetParameterLayer.
+    new_scale = torch.randn(dim)
+    set_parameter_layer = SetParameterLayer("scale", new_scale)
+
+    # Wrap the RMSNorm layer in a CustomFluxRMSNorm layer.
+    custom_flux_rms_norm = wrap_custom_layer(rms_norm, CustomFluxRMSNorm)
+    custom_flux_rms_norm.add_patch(set_parameter_layer, 1.0)
+
+    # Run the CustomFluxRMSNorm layer.
+    input = torch.randn(1, dim)
+    expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
+    output_custom = custom_flux_rms_norm(input)
+    assert torch.allclose(output_custom, expected_output, atol=1e-6)
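
Background on the wrapping pattern: the diff above only touches the signature of wrap_custom_layer; its body is not shown, and the HACK comment describes it as copying attributes from the original layer onto a new instance created without calling __init__(). Below is a minimal, self-contained sketch of that pattern under those assumptions. The names DoubledLinear and wrap_custom_layer_sketch are hypothetical stand-ins, not InvokeAI APIs, and the attribute-copy body is inferred from the comment rather than taken from the repository.

from typing import TypeVar

import torch

T = TypeVar("T", bound=torch.nn.Module)


class DoubledLinear(torch.nn.Linear):
    # Hypothetical custom subclass, standing in for a layer like CustomFluxRMSNorm.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x) * 2


def wrap_custom_layer_sketch(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T:
    # Sketch only (assumption): allocate the custom layer without running
    # custom_layer_type.__init__(), then copy the original module's attributes
    # (parameters, buffers, hooks, config) onto the new instance.
    custom_layer = custom_layer_type.__new__(custom_layer_type)
    custom_layer.__dict__.update(module_to_wrap.__dict__)
    return custom_layer


linear = torch.nn.Linear(4, 4)
# With the TypeVar-based signature, a type checker infers the result as DoubledLinear
# rather than a bare torch.nn.Module, which is the point of the signature change in the patch.
wrapped = wrap_custom_layer_sketch(linear, DoubledLinear)
x = torch.randn(2, 4)
assert torch.allclose(wrapped(x), linear(x) * 2)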