mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-03 15:31:55 +08:00
12 lines
496 B
Python
12 lines
496 B
Python
import torch
|
|
|
|
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
|
|
|
|
|
class Conv1dSidecarWrapper(BaseSidecarWrapper):
    """Sidecar wrapper that adds aggregated patch residuals to a wrapped 1D convolution.

    The wrapped module's own parameters are left untouched. The patch contribution is
    computed as a separate convolution over the same input and summed with the wrapped
    module's output.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the wrapped module's output plus the patch residual contribution.

        Args:
            input: Tensor passed both to the wrapped module and to the residual convolution.

        Returns:
            ``self.orig_module(input)`` plus the convolution of ``input`` with the aggregated
            ``"weight"`` residual (and optional ``"bias"`` residual). If no weight residual is
            present, only the bias residual (if any) is added.
        """
        residuals = self._aggregate_patch_parameters(self._patches_and_weights)
        orig_module = self.orig_module
        output = orig_module(input)

        weight_residual = residuals.get("weight", None)
        bias_residual = residuals.get("bias", None)

        if weight_residual is None:
            # No weight patch to convolve with: avoid a KeyError and apply only the bias
            # residual, if there is one.
            if bias_residual is None:
                return output
            # Assumes a 1-D per-output-channel bias, as torch.nn.functional.conv1d expects;
            # the trailing axis lets it broadcast over the length dimension.
            return output + bias_residual.unsqueeze(-1)

        # Convolution is linear in its weight only when both convolutions share the same
        # stride/padding/dilation/groups, so mirror the wrapped module's settings. Otherwise
        # the residual's shape would not match `output` for non-default configurations.
        # The getattr fallbacks are the conv1d defaults, which preserves the previous
        # behaviour for wrapped modules that do not expose these attributes.
        # NOTE(review): a non-"zeros" `padding_mode` on the wrapped module is not replicated
        # here (the residual is zero-padded) - confirm whether such modules get wrapped.
        return output + torch.nn.functional.conv1d(
            input,
            weight_residual,
            bias_residual,
            stride=getattr(orig_module, "stride", 1),
            padding=getattr(orig_module, "padding", 0),
            dilation=getattr(orig_module, "dilation", 1),
            groups=getattr(orig_module, "groups", 1),
        )
|