from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING

from diffusers import UNet2DConditionModel

from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

if TYPE_CHECKING:
    from invokeai.app.invocations.model import ModelIdentifierField
    from invokeai.app.services.shared.invocation_context import InvocationContext
    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class LoRAExt(ExtensionBase):
    """Extension that applies a single LoRA patch to the UNet for the duration of a denoising run."""

    def __init__(
        self,
        node_context: InvocationContext,
        model_id: ModelIdentifierField,
        weight: float,
    ):
        super().__init__()
        self._node_context = node_context
        self._model_id = model_id
        self._weight = weight

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        # Load the LoRA model via the model manager and confirm it is a raw patch.
        lora_model = self._node_context.models.load(self._model_id).model
        assert isinstance(lora_model, ModelPatchRaw)

        # Patch the UNet weights in place (direct patching is forced); the
        # pre-patch weights are saved to `original_weights` so the caller can
        # restore them after denoising.
        LayerPatcher.apply_smart_model_patch(
            model=unet,
            prefix="lora_unet_",
            patch=lora_model,
            patch_weight=self._weight,
            original_weights=original_weights,
            original_modules={},
            dtype=unet.dtype,
            force_direct_patching=True,
            force_sidecar_patching=False,
        )
        # Release the local reference before yielding so the LoRA weights can be
        # garbage-collected while the denoising loop runs.
        del lora_model

        yield
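
# --- Hypothetical usage sketch (not part of the original file) ---
# The extensions manager is expected to enter `patch_unet` around the denoising
# loop and restore the UNet from `original_weights` afterwards. Assuming a live
# `context` (InvocationContext), a `lora_field` (ModelIdentifierField), and a
# loaded `unet`, the flow looks roughly like:
#
#   ext = LoRAExt(node_context=context, model_id=lora_field, weight=0.75)
#   original_weights = OriginalWeightsStorage()
#   with ext.patch_unet(unet, original_weights):
#       ...  # denoise with the LoRA weights applied
#   # the caller restores the original UNet weights from `original_weights` here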