From 8b4b0ff0cfbf3bcdd7193ca651209d7b65b2ead4 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 19:00:24 +0000 Subject: [PATCH] Fix bug in CustomConv1d and CustomConv2d patch calculations. --- .../custom_modules/custom_conv1d.py | 10 +++++++--- .../custom_modules/custom_conv2d.py | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py index b59b5a2aae..e65b325924 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py @@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin): @@ -21,9 +24,10 @@ class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin): orig_params=orig_params, device=input.device, ) - return self._conv_forward( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) + + weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None)) + bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None)) + return self._conv_forward(input, weight, bias) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py index 1077b47ed5..91f08fb96b 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): @@ -21,9 +24,10 @@ class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): orig_params=orig_params, device=input.device, ) - return self._conv_forward( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) + + weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None)) + bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None)) + return self._conv_forward(input, weight, bias) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device)