Mirror of https://github.com/invoke-ai/InvokeAI.git
Add CustomFluxRMSNorm layer.
parent f692e217ea
commit 93e76b61d6
@@ -0,0 +1,34 @@
+import torch
+
+from invokeai.backend.flux.modules.layers import RMSNorm
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
+
+
+class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
+        # Currently, CustomFluxRMSNorm layers only support patching with a single SetParameterLayer.
+        assert len(self._patches_and_weights) == 1
+        patch, _patch_weight = self._patches_and_weights[0]
+        assert isinstance(patch, SetParameterLayer)
+        assert patch.param_name == "scale"
+
+        # Apply the patch.
+        # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
+        # be handled.
+        return torch.nn.functional.rms_norm(x, patch.weight.shape, patch.weight, eps=1e-6)
+
+    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
+        scale = cast_to_device(self.scale, x.device)
+        return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            return self._autocast_forward_with_patches(x)
+        elif self._device_autocasting_enabled:
+            return self._autocast_forward(x)
+        else:
+            return super().forward(x)
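For reference, the arithmetic performed by _autocast_forward_with_patches can be reproduced with plain PyTorch. The sketch below is illustrative only, not part of the commit: it assumes PyTorch >= 2.4 (where torch.nn.functional.rms_norm is available), and new_scale is a made-up tensor standing in for SetParameterLayer.weight. It shows that the patched forward is simply RMS normalization scaled by the replacement "scale" parameter instead of the layer's own.

    import torch

    x = torch.randn(2, 8)
    new_scale = torch.randn(8)  # hypothetical stand-in for SetParameterLayer.weight

    # What _autocast_forward_with_patches computes: RMS norm with the replacement scale.
    patched = torch.nn.functional.rms_norm(x, new_scale.shape, new_scale, eps=1e-6)

    # The same computation written out by hand.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
    manual = (x / rms) * new_scale

    torch.testing.assert_close(patched, manual)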
@@ -1,5 +1,6 @@
 import torch
 
+from invokeai.backend.flux.modules.layers import RMSNorm
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
     CustomConv1d,
 )
@@ -9,6 +10,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
     CustomEmbedding,
 )
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
+    CustomFluxRMSNorm,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
     CustomGroupNorm,
 )
@@ -25,6 +29,7 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
     torch.nn.Conv2d: CustomConv2d,
     torch.nn.GroupNorm: CustomGroupNorm,
     torch.nn.Embedding: CustomEmbedding,
+    RMSNorm: CustomFluxRMSNorm,
 }
 
 try:
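The hunk above registers RMSNorm -> CustomFluxRMSNorm in AUTOCAST_MODULE_TYPE_MAPPING. The code that consumes this mapping is not shown in the diff, but a type mapping like this is commonly applied by walking a model and swapping each matching module's class for its custom subclass, so existing parameters are reused and only forward() changes. The sketch below is a hypothetical, self-contained illustration of that pattern, not the repository's actual wrapping code; MyNorm, MyCustomNorm, and apply_mapping are made-up names for the example.

    import torch

    class MyNorm(torch.nn.Module):  # stand-in for RMSNorm
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x

    class MyCustomNorm(MyNorm):  # stand-in for CustomFluxRMSNorm
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {MyNorm: MyCustomNorm}

    def apply_mapping(model: torch.nn.Module) -> None:
        for module in model.modules():
            custom_cls = TYPE_MAPPING.get(type(module))
            if custom_cls is not None:
                module.__class__ = custom_cls  # swap behavior in place, keep parameters

    model = torch.nn.Sequential(MyNorm())
    apply_mapping(model)
    assert isinstance(model[0], MyCustomNorm)
    assert torch.equal(model(torch.ones(1, 4)), torch.full((1, 4), 2.0))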
@@ -4,6 +4,7 @@ import gguf
 import pytest
 import torch
 
+from invokeai.backend.flux.modules.layers import RMSNorm
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
     AUTOCAST_MODULE_TYPE_MAPPING,
     AUTOCAST_MODULE_TYPE_MAPPING_INVERSE,
@@ -68,6 +69,7 @@ LayerUnderTest = tuple[torch.nn.Module, torch.Tensor, bool]
         "conv2d",
         "group_norm",
         "embedding",
+        "flux_rms_norm",
         "linear_with_ggml_quantized_tensor",
         "invoke_linear_8_bit_lt",
         "invoke_linear_nf4",
@@ -86,6 +88,8 @@ def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest:
         return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True)
     elif layer_type == "embedding":
         return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True)
+    elif layer_type == "flux_rms_norm":
+        return (RMSNorm(8), torch.randn(1, 8), True)
     elif layer_type == "linear_with_ggml_quantized_tensor":
         return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True)
     elif layer_type == "invoke_linear_8_bit_lt":
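The two test-file hunks above extend an existing parametrized fixture: the string added to the params list selects a new branch in layer_under_test, which returns a (layer, input, bool) tuple. The sketch below shows the general pytest pattern being extended; the fixture and test names are illustrative, LayerNorm stands in for RMSNorm so the example has no InvokeAI dependency, and the meaning of the trailing bool is not shown in the diff.

    import pytest
    import torch

    @pytest.fixture(params=["group_norm", "flux_rms_norm"])
    def layer_under_test_sketch(request: pytest.FixtureRequest):
        layer_type = request.param
        if layer_type == "group_norm":
            return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True)
        elif layer_type == "flux_rms_norm":
            # The real test uses RMSNorm(8); LayerNorm(8) stands in here.
            return (torch.nn.LayerNorm(8), torch.randn(1, 8), True)
        raise ValueError(f"Unknown layer type: {layer_type}")

    def test_layer_runs(layer_under_test_sketch):
        layer, x, _flag = layer_under_test_sketch  # _flag's meaning is not visible in this diff
        assert layer(x).shape == x.shape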
@@ -351,7 +355,7 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
 
 
 @parameterize_all_devices
-def test_linear_sidecar_patches(device: str, layer_type: str, patch_under_test: PatchUnderTest):
+def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest):
     patches, input = patch_under_test
 
     # Build the base layer under test.