48 lines
1.5 KiB
Python

from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING
from diffusers import UNet2DConditionModel
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LoRAExt(ExtensionBase):
    """Extension that patches a UNet with a single loaded patch model.

    The instance only records which model to load and the weight to apply;
    the model itself is loaded when ``patch_unet`` is entered.
    """

    def __init__(self, node_context: InvocationContext, model_id: ModelIdentifierField, weight: float):
        """Remember the invocation context, the model identifier and the weight.

        Args:
            node_context: Context whose ``models.load`` is used to fetch the patch.
            model_id: Identifier passed to ``node_context.models.load``.
            weight: Value forwarded as ``patch_weight`` when patching.
        """
        super().__init__()
        self._node_context = node_context
        self._model_id = model_id
        self._weight = weight

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        """Patch ``unet`` on entry, then yield control to the ``with`` body.

        Args:
            unet: Model to patch; its ``dtype`` is forwarded to the patcher.
            original_weights: Storage object handed to the patcher.

        Raises:
            AssertionError: If the loaded model is not a ``ModelPatchRaw``.

        NOTE(review): nothing is undone here after the ``yield``; presumably the
        caller restores the UNet from ``original_weights`` - confirm.
        """
        patch_model = self._node_context.models.load(self._model_id).model
        assert isinstance(patch_model, ModelPatchRaw)

        LayerPatcher.apply_smart_model_patch(
            # What gets patched, and with what.
            model=unet,
            patch=patch_model,
            prefix="lora_unet_",
            patch_weight=self._weight,
            dtype=unet.dtype,
            # Bookkeeping containers passed to the patcher.
            original_weights=original_weights,
            original_modules={},
            # Direct patching is requested explicitly; sidecar patching is not.
            force_direct_patching=True,
            force_sidecar_patching=False,
        )

        # Release the local reference before suspending at the yield so this
        # frame does not keep the loaded patch alive during the ``with`` body.
        del patch_model

        yield