2024-07-27 02:39:53 +03:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from contextlib import contextmanager
|
2024-09-10 14:45:40 +00:00
|
|
|
from typing import TYPE_CHECKING
|
2024-07-27 02:39:53 +03:00
|
|
|
|
|
|
|
from diffusers import UNet2DConditionModel
|
|
|
|
|
2024-12-17 17:19:12 +00:00
|
|
|
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
2024-12-14 15:40:25 +00:00
|
|
|
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
2024-07-27 02:39:53 +03:00
|
|
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from invokeai.app.invocations.model import ModelIdentifierField
|
|
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
2024-07-30 03:39:01 +03:00
|
|
|
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
2024-07-27 02:39:53 +03:00
|
|
|
|
|
|
|
|
|
|
|
class LoRAExt(ExtensionBase):
    """Extension that patches a UNet with a single LoRA-style model patch.

    The patch model is not loaded at construction time; only the information
    needed to load it (node context and model identifier) and the weight to
    apply it with are stored. Loading happens each time ``patch_unet`` is
    entered.
    """

    def __init__(
        self,
        node_context: InvocationContext,
        model_id: ModelIdentifierField,
        weight: float,
    ):
        """Store what is needed to load and apply the patch later.

        Args:
            node_context: Invocation context whose ``models.load`` is used to
                fetch the patch model.
            model_id: Identifier of the patch model to load.
            weight: Value forwarded as ``patch_weight`` when patching.
        """
        super().__init__()
        self._node_context = node_context
        self._model_id = model_id
        self._weight = weight

    def _load_lora_patch(self) -> ModelPatchRaw:
        """Load the model referenced by ``self._model_id`` and return it.

        Raises:
            AssertionError: If the loaded model is not a ``ModelPatchRaw``.
        """
        loaded = self._node_context.models.load(self._model_id)
        raw_patch = loaded.model
        # Narrow the type for static checkers and fail fast on a wrong model.
        assert isinstance(raw_patch, ModelPatchRaw)
        return raw_patch

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        """Apply the patch to ``unet`` on entry, then yield to the caller.

        Nothing is done in this method after the ``yield``; any restoration of
        the UNet is not handled here.

        Args:
            unet: Model to patch. Its ``dtype`` is forwarded to the patcher.
            original_weights: Storage object forwarded to the patcher.
                NOTE(review): presumably used to record pre-patch weights so
                the caller can restore them - confirm against LayerPatcher.
        """
        raw_patch = self._load_lora_patch()

        LayerPatcher.apply_smart_model_patch(
            model=unet,
            prefix="lora_unet_",
            patch=raw_patch,
            patch_weight=self._weight,
            original_weights=original_weights,
            # A throwaway dict: this extension keeps no reference to it.
            original_modules={},
            dtype=unet.dtype,
            # Direct patching is forced on and sidecar patching forced off.
            force_direct_patching=True,
            force_sidecar_patching=False,
        )

        # Drop the local reference before suspending at the yield so this
        # generator frame does not keep the patch model alive.
        del raw_patch

        yield
|