mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
Add unit test for a SetParameterLayer patch applied to a CustomFluxRMSNorm layer.
This commit is contained in:
parent
93e76b61d6
commit
918f541af8
@ -1,3 +1,5 @@
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.modules.layers import RMSNorm
|
||||
@ -52,7 +54,10 @@ except ImportError:
|
||||
AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
|
||||
|
||||
|
||||
def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[torch.nn.Module]):
|
||||
T = TypeVar("T", bound=torch.nn.Module)
|
||||
|
||||
|
||||
def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T:
|
||||
# HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
|
||||
# existing layer instance without calling __init__() on the original layer class. We achieve this by copying
|
||||
# the attributes from the original layer instance to the new instance.
|
||||
|
@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.modules.layers import RMSNorm
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
|
||||
CustomFluxRMSNorm,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
wrap_custom_layer,
|
||||
)
|
||||
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
|
||||
|
||||
|
||||
def test_custom_flux_rms_norm_patch():
|
||||
"""Test a SetParameterLayer patch on a CustomFluxRMSNorm layer."""
|
||||
# Create a RMSNorm layer.
|
||||
dim = 8
|
||||
rms_norm = RMSNorm(dim)
|
||||
|
||||
# Create a SetParameterLayer.
|
||||
new_scale = torch.randn(dim)
|
||||
set_parameter_layer = SetParameterLayer("scale", new_scale)
|
||||
|
||||
# Wrap the RMSNorm layer in a CustomFluxRMSNorm layer.
|
||||
custom_flux_rms_norm = wrap_custom_layer(rms_norm, CustomFluxRMSNorm)
|
||||
custom_flux_rms_norm.add_patch(set_parameter_layer, 1.0)
|
||||
|
||||
# Run the CustomFluxRMSNorm layer.
|
||||
input = torch.randn(1, dim)
|
||||
expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
|
||||
output_custom = custom_flux_rms_norm(input)
|
||||
assert torch.allclose(output_custom, expected_output, atol=1e-6)
|
Loading…
Reference in New Issue
Block a user