import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch


class BaseSidecarWrapper(torch.nn.Module):
    """A base class for sidecar wrappers.

    A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a list of patches as 'sidecar'
    patches. I.e. it applies the sidecar patches during forward inference without modifying the original module.

    Sidecar wrappers are typically used over regular patches when:
    - The original module is quantized and so the weights can't be patched in the usual way.
    - The original module is on the CPU and modifying the weights would require backing up the original weights and
      doubling the CPU memory usage.
    """

    def __init__(
        self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None
    ):
        super().__init__()
        self._orig_module = orig_module
        self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights

    @property
    def orig_module(self) -> torch.nn.Module:
        return self._orig_module

    def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
        """Add a patch to the sidecar wrapper."""
        self._patches_and_weights.append((patch, patch_weight))

    def _aggregate_patch_parameters(
        self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
    ) -> dict[str, torch.Tensor]:
        """Helper function that aggregates the parameters from all patches into a single dict."""
        params: dict[str, torch.Tensor] = {}

        for patch, patch_weight in patches_and_weights:
            # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the
            # original module, this might fail or return incorrect results.
            layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)

            for param_name, param_weight in layer_params.items():
                if param_name not in params:
                    params[param_name] = param_weight
                else:
                    params[param_name] += param_weight

        return params

    def forward(self, *args, **kwargs):  # type: ignore
        raise NotImplementedError()
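

# Illustrative sketch (not part of the original file): a minimal concrete subclass showing how a sidecar wrapper can
# apply its patches at inference time without touching the wrapped module's weights. The choice of torch.nn.Linear,
# the class name ExampleLinearSidecarWrapper, and the assumption that the aggregated patch parameters use the names
# "weight"/"bias" and are applied as an additive F.linear() pass are hypothetical; the real sidecar wrappers in
# invokeai.backend.patches may apply patches differently.
import torch.nn.functional as F


class ExampleLinearSidecarWrapper(BaseSidecarWrapper):
    """Hypothetical sidecar wrapper for a torch.nn.Linear module (illustration only)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Run the original (possibly quantized or CPU-resident) module without modifying its weights.
        output = self.orig_module(input)

        # Aggregate the patch parameters and apply them as an additive correction on top of the original output.
        params = self._aggregate_patch_parameters(self._patches_and_weights)
        if "weight" in params:
            output = output + F.linear(input, params["weight"], params.get("bias"))
        return output


# Example usage (hypothetical):
#   wrapped = ExampleLinearSidecarWrapper(orig_linear, [(some_patch, 0.75)])
#   output = wrapped(x)  # orig_linear's weights are never modified.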