mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
Enable LoRAPatcher.apply_smart_lora_patches(...) throughout the stack.
This commit is contained in:
parent
0148512038
commit
61253b91f1
@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
LayerPatcher.apply_model_patches(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=text_encoder,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_te_",
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
LayerPatcher.apply_model_patches(
|
||||
text_encoder,
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=text_encoder,
|
||||
patches=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
|
@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
LayerPatcher.apply_model_patches(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=unet,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_unet_",
|
||||
dtype=unet.dtype,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
|
@ -309,10 +309,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if config.format in [ModelFormat.Checkpoint]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_model_patches(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
dtype=inference_dtype,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
|
@ -22,6 +22,7 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_L
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -111,10 +112,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_model_patches(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=clip_text_encoder,
|
||||
patches=self._clip_lora_iterator(context),
|
||||
prefix=FLUX_LORA_CLIP_PREFIX,
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_L
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
|
||||
SD3_T5_MAX_SEQ_LEN = 256
|
||||
@ -150,10 +151,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_model_patches(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=clip_text_encoder,
|
||||
patches=self._clip_lora_iterator(context, clip_model),
|
||||
prefix=FLUX_LORA_CLIP_PREFIX,
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
|
@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info as unet,
|
||||
LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
|
||||
),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
Loading…
Reference in New Issue
Block a user