mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
Fix bug in CustomConv1d and CustomConv2d patch calculations.
This commit is contained in:
parent
6fd9b0a274
commit
8b4b0ff0cf
@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
|
||||
CustomModuleMixin,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
|
||||
add_nullable_tensors,
|
||||
)
|
||||
|
||||
|
||||
class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
|
||||
@ -21,9 +24,10 @@ class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
|
||||
orig_params=orig_params,
|
||||
device=input.device,
|
||||
)
|
||||
return self._conv_forward(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
||||
|
||||
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
|
||||
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
|
@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
|
||||
CustomModuleMixin,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
|
||||
add_nullable_tensors,
|
||||
)
|
||||
|
||||
|
||||
class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
|
||||
@ -21,9 +24,10 @@ class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
|
||||
orig_params=orig_params,
|
||||
device=input.device,
|
||||
)
|
||||
return self._conv_forward(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
||||
|
||||
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
|
||||
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
|
Loading…
Reference in New Issue
Block a user