From 3ed6e65a6e96afaec3ffa2c44e763a301e51b9f6 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 10 Dec 2024 17:27:33 +0000
Subject: [PATCH] Enable LoRAPatcher.apply_smart_lora_patches(...) throughout
 the stack.

---
 invokeai/app/invocations/compel.py                        | 6 ++++--
 invokeai/app/invocations/denoise_latents.py               | 3 ++-
 invokeai/app/invocations/flux_denoise.py                  | 3 ++-
 invokeai/app/invocations/flux_text_encoder.py             | 4 +++-
 invokeai/app/invocations/sd3_text_encoder.py              | 4 +++-
 .../invocations/tiled_multi_diffusion_denoise_latents.py  | 4 +++-
 6 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index fe8943bfcd..9142aa1de2 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LoRAPatcher.apply_smart_lora_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
                 prefix="lora_te_",
+                dtype=TorchDevice.choose_torch_dtype(),
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LoRAPatcher.apply_smart_lora_patches(
                 text_encoder,
                 patches=_lora_loader(),
                 prefix=lora_prefix,
+                dtype=TorchDevice.choose_torch_dtype(),
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 4cbbcf07af..d97adae227 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            LoRAPatcher.apply_lora_patches(
+            LoRAPatcher.apply_smart_lora_patches(
                 model=unet,
                 patches=_lora_loader(),
                 prefix="lora_unet_",
+                dtype=unet.dtype,
                 cached_weights=cached_weights,
             ),
         ):
diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index eb3e50f103..33581298d9 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -310,10 +310,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             if config.format in [ModelFormat.Checkpoint]:
                 # The model is non-quantized, so we can apply the LoRA weights directly into the model.
                 exit_stack.enter_context(
-                    LoRAPatcher.apply_lora_patches(
+                    LoRAPatcher.apply_smart_lora_patches(
                         model=transformer,
                         patches=self._lora_iterator(context),
                         prefix=FLUX_LORA_TRANSFORMER_PREFIX,
+                        dtype=inference_dtype,
                         cached_weights=cached_weights,
                     )
                 )
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index e887ba26f2..c0b65acfaf 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -22,6 +22,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation(
@@ -111,10 +112,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LoRAPatcher.apply_smart_lora_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context),
                     prefix=FLUX_LORA_CLIP_PREFIX,
+                    dtype=TorchDevice.choose_torch_dtype(),
                     cached_weights=cached_weights,
                 )
             )
diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py
index 5969eda095..11f4eaf63e 100644
--- a/invokeai/app/invocations/sd3_text_encoder.py
+++ b/invokeai/app/invocations/sd3_text_encoder.py
@@ -21,6 +21,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
 
 # The SD3 T5 Max Sequence Length set based on the default in diffusers.
 SD3_T5_MAX_SEQ_LEN = 256
@@ -150,10 +151,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LoRAPatcher.apply_smart_lora_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context, clip_model),
                     prefix=FLUX_LORA_CLIP_PREFIX,
+                    dtype=TorchDevice.choose_torch_dtype(),
                     cached_weights=cached_weights,
                 )
             )
diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
index 556600b412..c32ef29972 100644
--- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
+++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
@@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         with (
             ExitStack() as exit_stack,
             unet_info as unet,
-            LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+            LoRAPatcher.apply_smart_lora_patches(
+                model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
+            ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
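
---

Note: every call site in this patch follows the same shape, differing only in which model is patched, the LoRA key prefix, and where the dtype comes from (the text encoders use TorchDevice.choose_torch_dtype(), while the denoise invocations pass the model's own inference dtype). A minimal sketch of the pattern for reference, assuming only the keyword signature visible in the hunks above (model, patches, prefix, dtype, and optional cached_weights); load_text_encoder() and iter_lora_patches() are hypothetical stand-ins for what the invocation context normally provides:

    from contextlib import ExitStack

    from invokeai.backend.lora.lora_patcher import LoRAPatcher
    from invokeai.backend.util.devices import TorchDevice

    # Hypothetical stand-ins; in the real invocations these come from the
    # invocation context (see the call sites in the hunks above).
    text_encoder = load_text_encoder()   # a module already on its target device
    lora_patches = iter_lora_patches()   # iterable of LoRA patches to apply

    with ExitStack() as exit_stack:
        # Same pattern as the call sites above: the smart patcher takes an
        # explicit dtype so the LoRA weights can be matched to the model's
        # inference dtype while patching.
        exit_stack.enter_context(
            LoRAPatcher.apply_smart_lora_patches(
                model=text_encoder,
                patches=lora_patches,
                prefix="lora_te_",
                dtype=TorchDevice.choose_torch_dtype(),
            )
        )
        # ... run inference here; the patches are undone when the stack exits.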