mirror of https://github.com/invoke-ai/InvokeAI.git
Add CustomFluxRMSNorm layer.
parent f692e217ea
commit 93e76b61d6
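In brief, this commit (1) adds a new custom_flux_rms_norm.py module defining CustomFluxRMSNorm, a subclass of the FLUX RMSNorm layer that supports device autocasting and sidecar patching via a single SetParameterLayer; (2) registers RMSNorm -> CustomFluxRMSNorm in AUTOCAST_MODULE_TYPE_MAPPING; and (3) extends the autocast unit tests with a flux_rms_norm case and simplifies the test_linear_sidecar_patches signature.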
new file: invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py
@@ -0,0 +1,34 @@
+import torch
+
+from invokeai.backend.flux.modules.layers import RMSNorm
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
+
+
+class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
+        # Currently, CustomFluxRMSNorm layers only support patching with a single SetParameterLayer.
+        assert len(self._patches_and_weights) == 1
+        patch, _patch_weight = self._patches_and_weights[0]
+        assert isinstance(patch, SetParameterLayer)
+        assert patch.param_name == "scale"
+
+        # Apply the patch.
+        # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
+        # be handled.
+        return torch.nn.functional.rms_norm(x, patch.weight.shape, patch.weight, eps=1e-6)
+
+    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
+        scale = cast_to_device(self.scale, x.device)
+        return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            return self._autocast_forward_with_patches(x)
+        elif self._device_autocasting_enabled:
+            return self._autocast_forward(x)
+        else:
+            return super().forward(x)
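For context on the two forward paths above: torch.nn.functional.rms_norm(input, normalized_shape, weight, eps), available since PyTorch 2.4, normalizes over the trailing normalized_shape dimensions and then applies the weight. A minimal reference sketch for the 1-D weight case used here (rms_norm_reference is an illustrative name, not part of the repo):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension, matching
    # torch.nn.functional.rms_norm(x, weight.shape, weight, eps) for a 1-D weight.
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    return (x / rms) * weight

Because a SetParameterLayer patch replaces the "scale" parameter outright (as its name suggests) rather than adding a delta, _autocast_forward_with_patches can simply pass patch.weight as the weight argument in place of self.scale.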
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
@@ -1,5 +1,6 @@
 import torch
 
+from invokeai.backend.flux.modules.layers import RMSNorm
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
     CustomConv1d,
 )
@@ -9,6 +10,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
     CustomEmbedding,
 )
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
+    CustomFluxRMSNorm,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
     CustomGroupNorm,
 )
@@ -25,6 +29,7 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
     torch.nn.Conv2d: CustomConv2d,
     torch.nn.GroupNorm: CustomGroupNorm,
     torch.nn.Embedding: CustomEmbedding,
+    RMSNorm: CustomFluxRMSNorm,
 }
 
 try:
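The diff does not show how AUTOCAST_MODULE_TYPE_MAPPING is consumed, but the registration pattern implies an in-place class-swap wrapper. A minimal sketch under that assumption (apply_custom_layers is a hypothetical helper, not the repo's actual function):

import torch

def apply_custom_layers(model: torch.nn.Module, mapping: dict[type[torch.nn.Module], type[torch.nn.Module]]) -> None:
    # Swap each module's class for its custom counterpart in place, so existing
    # parameters and buffers are kept and only the forward() behavior changes.
    for module in model.modules():
        custom_cls = mapping.get(type(module))
        if custom_cls is not None:
            module.__class__ = custom_cls

Subclassing RMSNorm keeps the swapped module a drop-in replacement: the original scale parameter and state are untouched, and the unpatched path can still fall back to super().forward(x).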
@@ -4,6 +4,7 @@ import gguf
 import pytest
 import torch
 
+from invokeai.backend.flux.modules.layers import RMSNorm
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
     AUTOCAST_MODULE_TYPE_MAPPING,
     AUTOCAST_MODULE_TYPE_MAPPING_INVERSE,
@@ -68,6 +69,7 @@ LayerUnderTest = tuple[torch.nn.Module, torch.Tensor, bool]
     "conv2d",
     "group_norm",
     "embedding",
+    "flux_rms_norm",
     "linear_with_ggml_quantized_tensor",
     "invoke_linear_8_bit_lt",
     "invoke_linear_nf4",
@@ -86,6 +88,8 @@ def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest:
         return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True)
     elif layer_type == "embedding":
         return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True)
+    elif layer_type == "flux_rms_norm":
+        return (RMSNorm(8), torch.randn(1, 8), True)
     elif layer_type == "linear_with_ggml_quantized_tensor":
         return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True)
     elif layer_type == "invoke_linear_8_bit_lt":
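LayerUnderTest bundles a layer, a compatible example input, and a bool flag whose meaning is defined elsewhere in the test module. A hypothetical standalone check mirroring the new flux_rms_norm fixture entry above:

import torch
from invokeai.backend.flux.modules.layers import RMSNorm

layer, x, _flag = RMSNorm(8), torch.randn(1, 8), True
assert layer(x).shape == x.shape  # RMS normalization preserves the input shape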
@@ -351,7 +355,7 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
 
 
 @parameterize_all_devices
-def test_linear_sidecar_patches(device: str, layer_type: str, patch_under_test: PatchUnderTest):
+def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest):
     patches, input = patch_under_test
 
     # Build the base layer under test.