From 20acfc9a00c82fd7496785b4136ca64210755f32 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sat, 28 Dec 2024 20:49:17 +0000
Subject: [PATCH] Raise in CustomEmbedding and CustomGroupNorm if a patch is
 applied.

---
 .../torch_module_autocast/custom_modules/custom_embedding.py  | 3 +++
 .../torch_module_autocast/custom_modules/custom_group_norm.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py
index e6f0c5df21..e622b678fa 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py
@@ -20,6 +20,9 @@ class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            raise RuntimeError("Embedding layers do not support patches")
+
         if self._device_autocasting_enabled:
             return self._autocast_forward(input)
         else:
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py
index 66a46ac7ea..d02e2d533f 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py
@@ -13,6 +13,9 @@ class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin):
         return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            raise RuntimeError("GroupNorm layers do not support patches")
+
         if self._device_autocasting_enabled:
             return self._autocast_forward(input)
         else:
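
A quick way to exercise the new guard is a small pytest sketch like the one below.
It assumes CustomEmbedding can be constructed with the usual torch.nn.Embedding
arguments and that CustomModuleMixin initializes `_patches_and_weights` as a plain
list; both details sit outside this patch, so treat them as assumptions rather than
guarantees.

# Minimal verification sketch for the new guard (not part of the patch).
# Assumptions: CustomEmbedding accepts the standard torch.nn.Embedding
# constructor arguments, and CustomModuleMixin sets up
# `_patches_and_weights` as a plain list.
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
    CustomEmbedding,
)


def test_custom_embedding_raises_when_patched():
    embedding = CustomEmbedding(num_embeddings=10, embedding_dim=4)

    # Simulate an applied patch. forward() only checks that the list is
    # non-empty, so a stand-in object is enough for this test.
    embedding._patches_and_weights.append((object(), 1.0))

    with pytest.raises(RuntimeError, match="do not support patches"):
        embedding(torch.tensor([0, 1, 2]))

The same pattern applies to CustomGroupNorm, whose forward() now raises
"GroupNorm layers do not support patches" under the same condition.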