Add CustomFluxRMSNorm layer.

Ryan Dick 2024-12-28 20:33:38 +00:00
parent f692e217ea
commit 93e76b61d6
3 changed files with 44 additions and 1 deletion

@@ -0,0 +1,34 @@
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer


class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
        # Currently, CustomFluxRMSNorm layers only support patching with a single SetParameterLayer.
        assert len(self._patches_and_weights) == 1
        patch, _patch_weight = self._patches_and_weights[0]
        assert isinstance(patch, SetParameterLayer)
        assert patch.param_name == "scale"

        # Apply the patch.
        # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
        # be handled.
        return torch.nn.functional.rms_norm(x, patch.weight.shape, patch.weight, eps=1e-6)

    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = cast_to_device(self.scale, x.device)
        return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(x)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(x)
        else:
            return super().forward(x)
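For reference, here is a minimal standalone sketch (not part of this commit) of what the patched path above computes: when a SetParameterLayer replaces the "scale" parameter, _autocast_forward_with_patches is equivalent to RMS-normalizing the input over its last dimension and multiplying by the replacement scale, with the layer's own scale ignored. The eps of 1e-6 matches the calls above; torch.nn.functional.rms_norm requires PyTorch 2.4+.

import torch

# Stand-ins for the tensors involved.
x = torch.randn(2, 8)
new_scale = torch.randn(8)  # the tensor a SetParameterLayer targeting "scale" would supply

# What the patched forward computes: normalize x by its RMS over the last dim,
# then multiply by the replacement scale.
patched = torch.nn.functional.rms_norm(x, new_scale.shape, new_scale, eps=1e-6)

# The same result written out explicitly.
reference = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * new_scale
torch.testing.assert_close(patched, reference)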

@@ -1,5 +1,6 @@
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
    CustomConv1d,
)
@@ -9,6 +10,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
    CustomEmbedding,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
    CustomFluxRMSNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
    CustomGroupNorm,
)
@@ -25,6 +29,7 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
    torch.nn.Conv2d: CustomConv2d,
    torch.nn.GroupNorm: CustomGroupNorm,
    torch.nn.Embedding: CustomEmbedding,
    RMSNorm: CustomFluxRMSNorm,
}

try:
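For context on how this mapping is consumed, the sketch below is a hypothetical helper, not the project's actual apply/wrap code (which lives elsewhere in torch_module_autocast.py and may differ): each module whose type appears in the mapping is rerouted to its custom counterpart, and the mixin state that CustomFluxRMSNorm.forward() reads (_patches_and_weights, _device_autocasting_enabled) is initialized.

import torch

def apply_custom_layers(model: torch.nn.Module, mapping: dict[type, type]) -> None:
    # Hypothetical helper, for illustration only. Each custom class subclasses the
    # original layer, so reassigning __class__ reroutes forward() through the custom
    # implementation without touching the module's weights.
    for module in model.modules():
        custom_type = mapping.get(type(module))
        if custom_type is None:
            continue
        module.__class__ = custom_type
        module._patches_and_weights = []           # patch state read by forward()
        module._device_autocasting_enabled = True  # enable the cast_to_device path

With RMSNorm registered in the mapping, the Flux transformer's norm layers pick up the custom forward without any change to the code that calls them.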
@@ -4,6 +4,7 @@ import gguf
import pytest
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    AUTOCAST_MODULE_TYPE_MAPPING,
    AUTOCAST_MODULE_TYPE_MAPPING_INVERSE,
@@ -68,6 +69,7 @@ LayerUnderTest = tuple[torch.nn.Module, torch.Tensor, bool]
        "conv2d",
        "group_norm",
        "embedding",
        "flux_rms_norm",
        "linear_with_ggml_quantized_tensor",
        "invoke_linear_8_bit_lt",
        "invoke_linear_nf4",
@@ -86,6 +88,8 @@ def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest:
        return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True)
    elif layer_type == "embedding":
        return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True)
    elif layer_type == "flux_rms_norm":
        return (RMSNorm(8), torch.randn(1, 8), True)
    elif layer_type == "linear_with_ggml_quantized_tensor":
        return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True)
    elif layer_type == "invoke_linear_8_bit_lt":
@@ -351,7 +355,7 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
@parameterize_all_devices
def test_linear_sidecar_patches(device: str, layer_type: str, patch_under_test: PatchUnderTest):
def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest):
    patches, input = patch_under_test
    # Build the base layer under test.