mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-03 07:21:32 +08:00
21 lines
1.0 KiB
Python
21 lines
1.0 KiB
Python
import torch
|
|
|
|
from invokeai.backend.flux.modules.layers import RMSNorm
|
|
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
|
|
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
|
|
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
|
|
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
|
|
|
|
|
|
def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
    """Wrap ``orig_module`` in the sidecar wrapper class that matches its type.

    Args:
        orig_module: The module to wrap.

    Returns:
        A new wrapper instance constructed from ``orig_module``.

    Raises:
        ValueError: If ``orig_module`` is not an instance of any supported module type.
    """
    # Pairs are tested in order with isinstance, so subclasses of a listed
    # module type are matched by that type's entry as well.
    wrapper_by_module_type = (
        (torch.nn.Linear, LinearSidecarWrapper),
        (torch.nn.Conv1d, Conv1dSidecarWrapper),
        (torch.nn.Conv2d, Conv2dSidecarWrapper),
        (RMSNorm, FluxRMSNormSidecarWrapper),
    )
    for module_type, wrapper_cls in wrapper_by_module_type:
        if isinstance(orig_module, module_type):
            return wrapper_cls(orig_module)

    raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
|