Fix bug in CustomConv1d and CustomConv2d patch calculations.

This commit is contained in:
Ryan Dick 2024-12-29 19:00:24 +00:00
parent 6fd9b0a274
commit 8b4b0ff0cf
2 changed files with 14 additions and 6 deletions

View File

@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)
class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
@ -21,9 +24,10 @@ class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
orig_params=orig_params,
device=input.device,
)
return self._conv_forward(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)

View File

@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)
class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
@ -21,9 +24,10 @@ class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
orig_params=orig_params,
device=input.device,
)
return self._conv_forward(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)