2024-12-13 17:52:32 +00:00
|
|
|
import torch
|
|
|
|
|
2024-12-13 21:23:43 +00:00
|
|
|
from invokeai.backend.flux.modules.layers import RMSNorm
|
2024-12-13 17:52:32 +00:00
|
|
|
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
|
|
|
|
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
|
2024-12-13 21:23:43 +00:00
|
|
|
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
|
2024-12-13 17:52:32 +00:00
|
|
|
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
|
|
|
|
|
|
|
|
|
2024-12-13 20:02:05 +00:00
|
|
|
def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
    """Wrap ``orig_module`` in the sidecar wrapper that matches its type.

    Args:
        orig_module: The module to wrap. Supported types are
            ``torch.nn.Linear``, ``torch.nn.Conv1d``, ``torch.nn.Conv2d``, and
            the FLUX ``RMSNorm``.

    Returns:
        The sidecar-wrapped module.

    Raises:
        ValueError: If no sidecar wrapper exists for ``orig_module``'s type.
    """
    # Ordered (module type, wrapper class) pairs. Checked with isinstance, so
    # subclasses of the listed types are wrapped too; the order mirrors the
    # original if/elif chain.
    wrapper_by_type = (
        (torch.nn.Linear, LinearSidecarWrapper),
        (torch.nn.Conv1d, Conv1dSidecarWrapper),
        (torch.nn.Conv2d, Conv2dSidecarWrapper),
        (RMSNorm, FluxRMSNormSidecarWrapper),
    )
    for module_type, wrapper_cls in wrapper_by_type:
        if isinstance(orig_module, module_type):
            return wrapper_cls(orig_module)
    raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
|