Raise in CustomEmbedding and CustomGroupNorm if a patch is applied.

This commit is contained in:
Ryan Dick 2024-12-28 20:49:17 +00:00
parent 918f541af8
commit 20acfc9a00
2 changed files with 6 additions and 0 deletions

View File

@@ -20,6 +20,9 @@ class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin):
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
raise RuntimeError("Embedding layers do not support patches")
if self._device_autocasting_enabled:
return self._autocast_forward(input)
else:

View File

@@ -13,6 +13,9 @@ class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin):
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
raise RuntimeError("GroupNorm layers do not support patches")
if self._device_autocasting_enabled:
return self._autocast_forward(input)
else: