diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 723cd93a10..b535254cfd 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -20,8 +20,8 @@ from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import generate_ti_list from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, @@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation): # apply all patches while the model is on the target device text_encoder_info.model_on_device() as (cached_weights, text_encoder), tokenizer_info as tokenizer, - LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=text_encoder, patches=_lora_loader(), prefix="lora_te_", + dtype=text_encoder.dtype, cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. @@ -179,10 +180,11 @@ class SDXLPromptInvocationBase: # apply all patches while the model is on the target device text_encoder_info.model_on_device() as (cached_weights, text_encoder), tokenizer_info as tokenizer, - LayerPatcher.apply_model_patches( - text_encoder, + LayerPatcher.apply_smart_model_patches( + model=text_encoder, patches=_lora_loader(), prefix=lora_prefix, + dtype=text_encoder.dtype, cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 62ac6934c3..5aeeff57ad 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -39,8 +39,8 @@ from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.model_manager import BaseModelType, ModelVariantType from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs from invokeai.backend.stable_diffusion.diffusers_pipeline import ( @@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation): ModelPatcher.apply_freeu(unet, self.unet.freeu_config), SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. 
- LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=unet, patches=_lora_loader(), prefix="lora_unet_", + dtype=unet.dtype, cached_weights=cached_weights, ), ): diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 08bbd9f31c..d8bc8135bc 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -48,9 +48,9 @@ from invokeai.backend.flux.sampling_utils import ( ) from invokeai.backend.flux.text_conditioning import FluxTextConditioning from invokeai.backend.model_manager.config import ModelFormat +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice @@ -304,36 +304,33 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): config = transformer_info.config assert config is not None - # Apply LoRA models to the transformer. - # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. + # Determine if the model is quantized. + # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in + # slower inference than direct patching, but is agnostic to the quantization format. if config.format in [ModelFormat.Checkpoint]: - # The model is non-quantized, so we can apply the LoRA weights directly into the model. - exit_stack.enter_context( - LayerPatcher.apply_model_patches( - model=transformer, - patches=self._lora_iterator(context), - prefix=FLUX_LORA_TRANSFORMER_PREFIX, - cached_weights=cached_weights, - ) - ) + model_is_quantized = False elif config.format in [ ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized, ]: - # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference, - # than directly patching the weights, but is agnostic to the quantization format. - exit_stack.enter_context( - LayerPatcher.apply_model_sidecar_patches( - model=transformer, - patches=self._lora_iterator(context), - prefix=FLUX_LORA_TRANSFORMER_PREFIX, - dtype=inference_dtype, - ) - ) + model_is_quantized = True else: raise ValueError(f"Unsupported model format: {config.format}") + # Apply LoRA models to the transformer. + # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. + exit_stack.enter_context( + LayerPatcher.apply_smart_model_patches( + model=transformer, + patches=self._lora_iterator(context), + prefix=FLUX_LORA_TRANSFORMER_PREFIX, + dtype=inference_dtype, + cached_weights=cached_weights, + force_sidecar_patching=model_is_quantized, + ) + ) + # Prepare IP-Adapter extensions. 
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions( pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds, diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index c1113603f0..3f1f38c4a1 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -18,9 +18,9 @@ from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.conditioner import HFEncoder from invokeai.backend.model_manager.config import ModelFormat +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo @@ -111,10 +111,11 @@ class FluxTextEncoderInvocation(BaseInvocation): if clip_text_encoder_config.format in [ModelFormat.Diffusers]: # The model is non-quantized, so we can apply the LoRA weights directly into the model. exit_stack.enter_context( - LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=clip_text_encoder, patches=self._clip_lora_iterator(context), prefix=FLUX_LORA_CLIP_PREFIX, + dtype=clip_text_encoder.dtype, cached_weights=cached_weights, ) ) diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index f92977bd42..6569fa0a76 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -17,9 +17,9 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import SD3ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import ModelFormat +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo # The SD3 T5 Max Sequence Length set based on the default in diffusers. @@ -150,10 +150,11 @@ class Sd3TextEncoderInvocation(BaseInvocation): if clip_text_encoder_config.format in [ModelFormat.Diffusers]: # The model is non-quantized, so we can apply the LoRA weights directly into the model. 
exit_stack.enter_context( - LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=clip_text_encoder, patches=self._clip_lora_iterator(context, clip_model), prefix=FLUX_LORA_CLIP_PREFIX, + dtype=clip_text_encoder.dtype, cached_weights=cached_weights, ) ) diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py index 761e73d2bf..7c1442177f 100644 --- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py +++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py @@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import UNetField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import ( MultiDiffusionPipeline, @@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): with ( ExitStack() as exit_stack, unet_info as unet, - LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"), + LayerPatcher.apply_smart_model_patches( + model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype + ), ): assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index ab1a62db46..a5e1e3d539 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -1,9 +1,7 @@ import torch -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( - AUTOCAST_MODULE_TYPE_MAPPING, - apply_custom_layers_to_model, - remove_custom_layers_from_model, +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, ) from invokeai.backend.util.calc_tensor_size import calc_tensor_size from invokeai.backend.util.logging import InvokeAILogger @@ -45,10 +43,10 @@ class CachedModelWithPartialLoad: def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: """Find all modules that support autocasting.""" - return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING} + return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]: - keys_in_modules_that_do_not_support_autocast = set() + keys_in_modules_that_do_not_support_autocast: set[str] = set() for key in self._cpu_state_dict.keys(): for module_name in self._modules_that_support_autocast.keys(): if key.startswith(module_name): @@ -70,6 +68,11 @@ class CachedModelWithPartialLoad: if name in module._non_persistent_buffers_set: module._buffers[name] = buffer.to(device, copy=True) + def 
_set_autocast_enabled_in_all_modules(self, enabled: bool): + """Set autocast_enabled flag in all modules that support device autocasting.""" + for module in self._modules_that_support_autocast.values(): + module.set_device_autocasting_enabled(enabled) + @property def model(self) -> torch.nn.Module: return self._model @@ -114,7 +117,7 @@ class CachedModelWithPartialLoad: cur_state_dict = self._model.state_dict() - # First, process the keys *must* be loaded into VRAM. + # First, process the keys that *must* be loaded into VRAM. for key in self._keys_in_modules_that_do_not_support_autocast: param = cur_state_dict[key] if param.device.type == self._compute_device.type: @@ -157,10 +160,10 @@ class CachedModelWithPartialLoad: self._cur_vram_bytes += vram_bytes_loaded if fully_loaded: - remove_custom_layers_from_model(self._model) + self._set_autocast_enabled_in_all_modules(False) # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync. else: - apply_custom_layers_to_model(self._model) + self._set_autocast_enabled_in_all_modules(True) # Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in # the vram_bytes_loaded tracking. @@ -197,5 +200,5 @@ class CachedModelWithPartialLoad: # We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom # layers. - apply_custom_layers_to_model(self._model) + self._set_autocast_enabled_in_all_modules(True) return vram_bytes_freed diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index f61e2963a7..dbc3670c95 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -13,6 +13,9 @@ from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger @@ -143,6 +146,10 @@ class ModelCache: size = calc_model_size_by_data(self._logger, model) self.make_room(size) + # Inject custom modules into the model. 
+ if isinstance(model, torch.nn.Module): + apply_custom_layers_to_model(model) + running_on_cpu = self._execution_device == torch.device("cpu") state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py deleted file mode 100644 index 8a1bacf683..0000000000 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch - -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device - -# This file contains custom torch.nn.Module classes that support streaming of weights to the target device. -# Each class sub-classes the original module type that is is replacing, so the following properties are preserved: -# - isinstance(m, torch.nn.OrginalModule) should still work. -# - Patching the weights (e.g. for LoRA) should still work if non-quantized. - - -class CustomLinear(torch.nn.Linear): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) - return torch.nn.functional.linear(input, weight, bias) - - -class CustomConv1d(torch.nn.Conv1d): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) - return self._conv_forward(input, weight, bias) - - -class CustomConv2d(torch.nn.Conv2d): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) - return self._conv_forward(input, weight, bias) - - -class CustomGroupNorm(torch.nn.GroupNorm): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) - - -class CustomEmbedding(torch.nn.Embedding): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - return torch.nn.functional.embedding( - input, - weight, - self.padding_idx, - self.max_norm, - self.norm_type, - self.scale_grad_by_freq, - self.sparse, - ) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/README.md b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/README.md new file mode 100644 index 0000000000..cadb1b6dd5 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/README.md @@ -0,0 +1,8 @@ + +This directory contains custom implementations of common torch.nn.Module classes that add support for: +- Streaming weights to the execution device +- Applying sidecar patches at execution time (e.g. sidecar LoRA layers) + +Each custom class sub-classes the original module type that is is replacing, so the following properties are preserved: +- `isinstance(m, torch.nn.OrginalModule)` should still work. +- Patching the weights directly (e.g. for LoRA) should still work. (Of course, this is not possible for quantized layers, hence the sidecar support.) 
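To make the README's guarantees concrete, here is a minimal, self-contained sketch (not part of the diff) of the pattern these custom modules follow: subclass the original torch.nn module so that isinstance checks and direct weight patching keep working, and cast parameters to the activation's device at forward time. The cast_to_device helper below is a simplified stand-in for the one used by the real custom modules.

import torch


def cast_to_device(t, device):
    # Simplified stand-in for the cast_to_device helper used by the real custom modules.
    if t is None:
        return None
    return t.to(device) if t.device != device else t


class ToyCustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Stream the parameters to wherever the activations live, then run the normal linear op.
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


layer = ToyCustomLinear(4, 2)
assert isinstance(layer, torch.nn.Linear)  # isinstance() against the original type still holds.
print(layer(torch.randn(1, 4)).shape)  # torch.Size([1, 2])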
diff --git a/invokeai/backend/patches/sidecar_wrappers/__init__.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/__init__.py similarity index 100% rename from invokeai/backend/patches/sidecar_wrappers/__init__.py rename to invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/__init__.py diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py new file mode 100644 index 0000000000..e65b325924 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py @@ -0,0 +1,43 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) + + +class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin): + def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + + # Prepare the original parameters for the patch aggregation. + orig_params = {"weight": weight, "bias": bias} + # Filter out None values. + orig_params = {k: v for k, v in orig_params.items() if v is not None} + + aggregated_param_residuals = self._aggregate_patch_parameters( + patches_and_weights=self._patches_and_weights, + orig_params=orig_params, + device=input.device, + ) + + weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None)) + bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None)) + return self._conv_forward(input, weight, bias) + + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return self._conv_forward(input, weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(input) + elif self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py new file mode 100644 index 0000000000..91f08fb96b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -0,0 +1,43 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) + + +class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): + def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = 
cast_to_device(self.bias, input.device) + + # Prepare the original parameters for the patch aggregation. + orig_params = {"weight": weight, "bias": bias} + # Filter out None values. + orig_params = {k: v for k, v in orig_params.items() if v is not None} + + aggregated_param_residuals = self._aggregate_patch_parameters( + patches_and_weights=self._patches_and_weights, + orig_params=orig_params, + device=input.device, + ) + + weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None)) + bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None)) + return self._conv_forward(input, weight, bias) + + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return self._conv_forward(input, weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(input) + elif self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py new file mode 100644 index 0000000000..e622b678fa --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py @@ -0,0 +1,29 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) + + +class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin): + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + return torch.nn.functional.embedding( + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + raise RuntimeError("Embedding layers do not support patches") + + if self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py new file mode 100644 index 0000000000..dccbe4af6c --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py @@ -0,0 +1,36 @@ +import torch + +from invokeai.backend.flux.modules.layers import RMSNorm +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) +from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer + + +class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin): + def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor: + # Currently, CustomFluxRMSNorm layers only support patching with a single SetParameterLayer. 
+ assert len(self._patches_and_weights) == 1 + patch, _patch_weight = self._patches_and_weights[0] + assert isinstance(patch, SetParameterLayer) + assert patch.param_name == "scale" + + scale = cast_to_device(patch.weight, x.device) + + # Apply the patch. + # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should + # be handled. + return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6) + + def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: + scale = cast_to_device(self.scale, x.device) + return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(x) + elif self._device_autocasting_enabled: + return self._autocast_forward(x) + else: + return super().forward(x) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py new file mode 100644 index 0000000000..d02e2d533f --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py @@ -0,0 +1,22 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) + + +class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin): + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + raise RuntimeError("GroupNorm layers do not support patches") + + if self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py similarity index 61% rename from invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py rename to invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py index 3941a2af6b..2b9d8e9e98 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py @@ -2,11 +2,20 @@ import bitsandbytes as bnb import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( + autocast_linear_forward_sidecar_patches, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt -class 
CustomInvokeLinear8bitLt(InvokeLinear8bitLt): - def forward(self, x: torch.Tensor) -> torch.Tensor: +class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin): + def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor: + return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights) + + def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: matmul_state = bnb.MatmulLtState() matmul_state.threshold = self.state.threshold matmul_state.has_fp16_weights = self.state.has_fp16_weights @@ -25,3 +34,11 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt): # it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be # on the wrong device. return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(x) + elif self._device_autocasting_enabled: + return self._autocast_forward(x) + else: + return super().forward(x) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py similarity index 71% rename from invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py rename to invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py index c697b3c7b4..89284d5509 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py @@ -4,11 +4,20 @@ import bitsandbytes as bnb import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( + autocast_linear_forward_sidecar_patches, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 -class CustomInvokeLinearNF4(InvokeLinearNF4): - def forward(self, x: torch.Tensor) -> torch.Tensor: +class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin): + def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor: + return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights) + + def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self) # weights are cast automatically as Int8Params, but the bias has to be cast manually @@ -43,3 +52,11 @@ class CustomInvokeLinearNF4(InvokeLinearNF4): bias = cast_to_device(self.bias, x.device) return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(x) + elif self._device_autocasting_enabled: + return self._autocast_forward(x) + else: + return super().forward(x) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py 
b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py new file mode 100644 index 0000000000..7d5784563e --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py @@ -0,0 +1,106 @@ +import copy + +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer +from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer +from invokeai.backend.patches.layers.lora_layer import LoRALayer + + +def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor: + """An optimized implementation of the residual calculation for a sidecar linear LoRALayer.""" + x = torch.nn.functional.linear(input, lora_layer.down) + if lora_layer.mid is not None: + x = torch.nn.functional.linear(x, lora_layer.mid) + x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias) + x *= lora_weight * lora_layer.scale() + return x + + +def concatenated_lora_forward( + input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float +) -> torch.Tensor: + """An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer.""" + x_chunks: list[torch.Tensor] = [] + for lora_layer in concatenated_lora_layer.lora_layers: + x_chunk = torch.nn.functional.linear(input, lora_layer.down) + if lora_layer.mid is not None: + x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid) + x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias) + x_chunk *= lora_weight * lora_layer.scale() + x_chunks.append(x_chunk) + + # TODO(ryand): Generalize to support concat_axis != 0. + assert concatenated_lora_layer.concat_axis == 0 + x = torch.cat(x_chunks, dim=-1) + return x + + +def autocast_linear_forward_sidecar_patches( + orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]] +) -> torch.Tensor: + """A function that runs a linear layer (quantized or non-quantized) with sidecar patches for a linear layer. + Compatible with both quantized and non-quantized Linear layers. + """ + # First, apply the original linear layer. + # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which + # change the linear layer's in_features. + orig_input = input + input = orig_input[..., : orig_module.in_features] + output = orig_module._autocast_forward(input) + + # Then, apply layers for which we have optimized implementations. + unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = [] + for patch, patch_weight in patches_and_weights: + # Shallow copy the patch so that we can cast it to the target device without modifying the original patch. + patch = copy.copy(patch) + patch.to(input.device) + + if isinstance(patch, FluxControlLoRALayer): + # Note that we use the original input here, not the sliced input. 
+ output += linear_lora_forward(orig_input, patch, patch_weight) + elif isinstance(patch, LoRALayer): + output += linear_lora_forward(input, patch, patch_weight) + elif isinstance(patch, ConcatenatedLoRALayer): + output += concatenated_lora_forward(input, patch, patch_weight) + else: + unprocessed_patches_and_weights.append((patch, patch_weight)) + + # Finally, apply any remaining patches. + if len(unprocessed_patches_and_weights) > 0: + # Prepare the original parameters for the patch aggregation. + orig_params = {"weight": orig_module.weight, "bias": orig_module.bias} + # Filter out None values. + orig_params = {k: v for k, v in orig_params.items() if v is not None} + + aggregated_param_residuals = orig_module._aggregate_patch_parameters( + unprocessed_patches_and_weights, orig_params=orig_params, device=input.device + ) + output += torch.nn.functional.linear( + input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) + ) + + return output + + +class CustomLinear(torch.nn.Linear, CustomModuleMixin): + def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: + return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights) + + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return torch.nn.functional.linear(input, weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(input) + elif self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py new file mode 100644 index 0000000000..a7312517a4 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py @@ -0,0 +1,63 @@ +import copy + +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch + + +class CustomModuleMixin: + """A mixin class for custom modules that enables device autocasting of module parameters.""" + + def __init__(self): + self._device_autocasting_enabled = False + self._patches_and_weights: list[tuple[BaseLayerPatch, float]] = [] + + def set_device_autocasting_enabled(self, enabled: bool): + """Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to + disable autocasting, which results in slightly faster execution speed when we know that device autocasting is + not needed. 
+ """ + self._device_autocasting_enabled = enabled + + def is_device_autocasting_enabled(self) -> bool: + """Check if device autocasting is enabled for the module.""" + return self._device_autocasting_enabled + + def add_patch(self, patch: BaseLayerPatch, patch_weight: float): + """Add a patch to the module.""" + self._patches_and_weights.append((patch, patch_weight)) + + def clear_patches(self): + """Clear all patches from the module.""" + self._patches_and_weights = [] + + def get_num_patches(self) -> int: + """Get the number of patches in the module.""" + return len(self._patches_and_weights) + + def _aggregate_patch_parameters( + self, + patches_and_weights: list[tuple[BaseLayerPatch, float]], + orig_params: dict[str, torch.Tensor], + device: torch.device | None = None, + ): + """Helper function that aggregates the parameters from all patches into a single dict.""" + params: dict[str, torch.Tensor] = {} + + for patch, patch_weight in patches_and_weights: + if device is not None: + # Shallow copy the patch so that we can cast it to the target device without modifying the original patch. + patch = copy.copy(patch) + patch.to(device) + + # TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original + # parameters, this might fail or return incorrect results. + layer_params = patch.get_parameters(orig_params, weight=patch_weight) + + for param_name, param_weight in layer_params.items(): + if param_name not in params: + params[param_name] = param_weight + else: + params[param_name] += param_weight + + return params diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/utils.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/utils.py new file mode 100644 index 0000000000..60294d9e0c --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/utils.py @@ -0,0 +1,30 @@ +from typing import overload + +import torch + + +@overload +def add_nullable_tensors(a: None, b: None) -> None: ... + + +@overload +def add_nullable_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... + + +@overload +def add_nullable_tensors(a: torch.Tensor, b: None) -> torch.Tensor: ... + + +@overload +def add_nullable_tensors(a: None, b: torch.Tensor) -> torch.Tensor: ... 
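As a brief aside on the helper declared here (its implementation follows immediately below): add_nullable_tensors treats None as "absent" rather than zero, which is what lets the conv modules above merge an optional bias or patch residual without special-casing. A small illustrative usage, assuming the package added in this PR is importable:

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
    add_nullable_tensors,
)

w = torch.ones(2, 2)
assert add_nullable_tensors(None, None) is None        # nothing to combine
assert torch.equal(add_nullable_tensors(w, None), w)   # missing residual leaves the base unchanged
assert torch.equal(add_nullable_tensors(w, w), 2 * w)  # both present: elementwise sum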
+ + +def add_nullable_tensors(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor | None: + if a is None and b is None: + return None + elif a is None: + return b + elif b is None: + return a + else: + return a + b diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py index 825eebf64e..0e271eaec5 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -1,12 +1,29 @@ +from typing import TypeVar + import torch -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import ( +from invokeai.backend.flux.modules.layers import RMSNorm +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import ( CustomConv1d, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import ( CustomConv2d, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import ( CustomEmbedding, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import ( + CustomFluxRMSNorm, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import ( CustomGroupNorm, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( CustomLinear, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = { torch.nn.Linear: CustomLinear, @@ -14,14 +31,15 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] torch.nn.Conv2d: CustomConv2d, torch.nn.GroupNorm: CustomGroupNorm, torch.nn.Embedding: CustomEmbedding, + RMSNorm: CustomFluxRMSNorm, } try: # These dependencies are not expected to be present on MacOS. - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import ( + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import ( CustomInvokeLinear8bitLt, ) - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import ( + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( CustomInvokeLinearNF4, ) from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt @@ -33,24 +51,55 @@ except ImportError: pass -def apply_custom_layers_to_model(model: torch.nn.Module): - def apply_custom_layers(module: torch.nn.Module): - override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None) +AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()} + + +T = TypeVar("T", bound=torch.nn.Module) + + +def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T: + # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an + # existing layer instance without calling __init__() on the original layer class. 
We achieve this by copying + # the attributes from the original layer instance to the new instance. + custom_layer = custom_layer_type.__new__(custom_layer_type) + # Note that we share the __dict__. + # TODO(ryand): In the future, we may want to do a shallow copy of the __dict__. + custom_layer.__dict__ = module_to_wrap.__dict__ + + # Initialize the CustomModuleMixin fields. + CustomModuleMixin.__init__(custom_layer) # type: ignore + return custom_layer + + +def unwrap_custom_layer(custom_layer: torch.nn.Module, original_layer_type: type[torch.nn.Module]): + # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an + # existing layer instance without calling __init__() on the original layer class. We achieve this by copying + # the attributes from the original layer instance to the new instance. + original_layer = original_layer_type.__new__(original_layer_type) + # Note that we share the __dict__. + # TODO(ryand): In the future, we may want to do a shallow copy of the __dict__ and strip out the CustomModuleMixin + # fields. + original_layer.__dict__ = custom_layer.__dict__ + return original_layer + + +def apply_custom_layers_to_model(module: torch.nn.Module, device_autocasting_enabled: bool = False): + for name, submodule in module.named_children(): + override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(submodule), None) if override_type is not None: - module.__class__ = override_type - - # model.apply(...) calls apply_custom_layers(...) on each module in the model. - model.apply(apply_custom_layers) + custom_layer = wrap_custom_layer(submodule, override_type) + # TODO(ryand): In the future, we should manage this flag on a per-module basis. + custom_layer.set_device_autocasting_enabled(device_autocasting_enabled) + setattr(module, name, custom_layer) + else: + # Recursively apply to submodules + apply_custom_layers_to_model(submodule, device_autocasting_enabled) -def remove_custom_layers_from_model(model: torch.nn.Module): - # Invert AUTOCAST_MODULE_TYPE_MAPPING. - original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()} - - def remove_custom_layers(module: torch.nn.Module): - override_type = original_module_type_mapping.get(type(module), None) +def remove_custom_layers_from_model(module: torch.nn.Module): + for name, submodule in module.named_children(): + override_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE.get(type(submodule), None) if override_type is not None: - module.__class__ = override_type - - # model.apply(...) calls remove_custom_layers(...) on each module in the model. 
- model.apply(remove_custom_layers) + setattr(module, name, unwrap_custom_layer(submodule, override_type)) + else: + remove_custom_layers_from_model(submodule) diff --git a/invokeai/backend/patches/model_patcher.py b/invokeai/backend/patches/layer_patcher.py similarity index 58% rename from invokeai/backend/patches/model_patcher.py rename to invokeai/backend/patches/layer_patcher.py index 14b92a26a8..463d753b9d 100644 --- a/invokeai/backend/patches/model_patcher.py +++ b/invokeai/backend/patches/layer_patcher.py @@ -7,8 +7,6 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.patches.pad_with_zeros import pad_with_zeros -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper -from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage @@ -17,58 +15,64 @@ class LayerPatcher: @staticmethod @torch.no_grad() @contextmanager - def apply_model_patches( + def apply_smart_model_patches( model: torch.nn.Module, patches: Iterable[Tuple[ModelPatchRaw, float]], prefix: str, + dtype: torch.dtype, cached_weights: Optional[Dict[str, torch.Tensor]] = None, + force_direct_patching: bool = False, + force_sidecar_patching: bool = False, ): - """Apply one or more LoRA patches to a model within a context manager. - - Args: - model (torch.nn.Module): The model to patch. - patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and - associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory - all at once. - prefix (str): The keys in the patches will be filtered to only include weights with this prefix. - cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in - CPU RAM, for efficient unpatching purposes. + """Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each + module. """ + + # original_weights are stored for unpatching layers that are directly patched. original_weights = OriginalWeightsStorage(cached_weights) + # original_modules are stored for unpatching layers that are wrapped. + original_modules: dict[str, torch.nn.Module] = {} try: for patch, patch_weight in patches: - LayerPatcher.apply_model_patch( + LayerPatcher.apply_smart_model_patch( model=model, prefix=prefix, patch=patch, patch_weight=patch_weight, original_weights=original_weights, + original_modules=original_modules, + dtype=dtype, + force_direct_patching=force_direct_patching, + force_sidecar_patching=force_sidecar_patching, ) - del patch yield finally: + # Restore directly patched layers. for param_key, weight in original_weights.get_changed_weights(): cur_param = model.get_parameter(param_key) cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True) + # Clear patches from all patched modules. + # Note: This logic assumes no nested modules in original_modules. 
+ for orig_module in original_modules.values(): + orig_module.clear_patches() + @staticmethod @torch.no_grad() - def apply_model_patch( + def apply_smart_model_patch( model: torch.nn.Module, prefix: str, patch: ModelPatchRaw, patch_weight: float, original_weights: OriginalWeightsStorage, + original_modules: dict[str, torch.nn.Module], + dtype: torch.dtype, + force_direct_patching: bool, + force_sidecar_patching: bool, ): - """Apply a single LoRA patch to a model. - - Args: - model (torch.nn.Module): The model to patch. - prefix (str): A string prefix that precedes keys used in the LoRAs weight layers. - patch (LoRAModelRaw): The LoRA model to patch in. - patch_weight (float): The weight of the LoRA patch. - original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching. + """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct + patching or a sidecar wrapper for each module. """ if patch_weight == 0: return @@ -89,13 +93,50 @@ class LayerPatcher: model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened ) - LayerPatcher._apply_model_layer_patch( - module_to_patch=module, - module_to_patch_key=module_key, - patch=layer, - patch_weight=patch_weight, - original_weights=original_weights, - ) + # Decide whether to use direct patching or a sidecar patch. + # Direct patching is preferred, because it results in better runtime speed. + # Reasons to use sidecar patching: + # - The module is quantized, so the caller passed force_sidecar_patching=True. + # - The module already has sidecar patches. + # - The module is on the CPU (and we don't want to store a second full copy of the original weights on the + # CPU, since this would double the RAM usage) + # NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller + # and that the caller will set force_sidecar_patching=True if the layer is quantized. + # TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows + # forcing full patching even on the CPU? + use_sidecar_patching = False + if force_direct_patching and force_sidecar_patching: + raise ValueError("Cannot force both direct and sidecar patching.") + elif force_direct_patching: + use_sidecar_patching = False + elif force_sidecar_patching: + use_sidecar_patching = True + elif module.get_num_patches() > 0: + use_sidecar_patching = True + elif LayerPatcher._is_any_part_of_layer_on_cpu(module): + use_sidecar_patching = True + + if use_sidecar_patching: + LayerPatcher._apply_model_layer_wrapper_patch( + module_to_patch=module, + module_to_patch_key=module_key, + patch=layer, + patch_weight=patch_weight, + original_modules=original_modules, + dtype=dtype, + ) + else: + LayerPatcher._apply_model_layer_patch( + module_to_patch=module, + module_to_patch_key=module_key, + patch=layer, + patch_weight=patch_weight, + original_weights=original_weights, + ) + + @staticmethod + def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool: + return any(p.device.type == "cpu" for p in layer.parameters()) @staticmethod @torch.no_grad() @@ -120,7 +161,9 @@ class LayerPatcher: # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. 
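For reference, the sidecar-versus-direct decision implemented above can be condensed as follows. This is an illustrative restatement rather than code from the PR; the free function name is hypothetical, and it assumes a module that exposes get_num_patches() like CustomModuleMixin does.

import torch


def choose_sidecar_patching(
    module: torch.nn.Module,
    force_direct: bool = False,
    force_sidecar: bool = False,
) -> bool:
    # Condensed restatement of the decision logic in LayerPatcher.apply_smart_model_patch.
    if force_direct and force_sidecar:
        raise ValueError("Cannot force both direct and sidecar patching.")
    if force_direct:
        return False
    if force_sidecar:
        # e.g. the caller knows the module is quantized.
        return True
    if getattr(module, "get_num_patches", lambda: 0)() > 0:
        # A module that already carries sidecar patches keeps receiving sidecar patches.
        return True
    # Partially offloaded (CPU-resident) weights are patched via sidecar to avoid keeping a second
    # full copy of the original weights in RAM.
    return any(p.device.type == "cpu" for p in module.parameters())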
- for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items(): + for param_name, param_weight in patch.get_parameters( + dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight + ).items(): param_key = module_to_patch_key + "." + param_name module_param = module_to_patch.get_parameter(param_name) @@ -143,93 +186,9 @@ class LayerPatcher: patch.to(device=TorchDevice.CPU_DEVICE) - @staticmethod - @torch.no_grad() - @contextmanager - def apply_model_sidecar_patches( - model: torch.nn.Module, - patches: Iterable[Tuple[ModelPatchRaw, float]], - prefix: str, - dtype: torch.dtype, - ): - """Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some - overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any - quantization format. - - Args: - model (torch.nn.Module): The model to patch. - patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and - associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory - all at once. - prefix (str): The keys in the patches will be filtered to only include weights with this prefix. - dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model, - since the sidecar layers are typically applied on top of quantized layers whose weight dtype is - different from their compute dtype. - """ - original_modules: dict[str, torch.nn.Module] = {} - try: - for patch, patch_weight in patches: - LayerPatcher._apply_model_sidecar_patch( - model=model, - prefix=prefix, - patch=patch, - patch_weight=patch_weight, - original_modules=original_modules, - dtype=dtype, - ) - yield - finally: - # Restore original modules. - # Note: This logic assumes no nested modules in original_modules. - for module_key, orig_module in original_modules.items(): - module_parent_key, module_name = LayerPatcher._split_parent_key(module_key) - parent_module = model.get_submodule(module_parent_key) - LayerPatcher._set_submodule(parent_module, module_name, orig_module) - - @staticmethod - def _apply_model_sidecar_patch( - model: torch.nn.Module, - patch: ModelPatchRaw, - patch_weight: float, - prefix: str, - original_modules: dict[str, torch.nn.Module], - dtype: torch.dtype, - ): - """Apply a single LoRA sidecar patch to a model.""" - - if patch_weight == 0: - return - - # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model - # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been - # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly - # without searching, but some legacy code still uses flattened keys. - layer_keys_are_flattened = "." 
not in next(iter(patch.layers.keys())) - - prefix_len = len(prefix) - - for layer_key, layer in patch.layers.items(): - if not layer_key.startswith(prefix): - continue - - module_key, module = LayerPatcher._get_submodule( - model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened - ) - - LayerPatcher._apply_model_layer_wrapper_patch( - model=model, - module_to_patch=module, - module_to_patch_key=module_key, - patch=layer, - patch_weight=patch_weight, - original_modules=original_modules, - dtype=dtype, - ) - @staticmethod @torch.no_grad() def _apply_model_layer_wrapper_patch( - model: torch.nn.Module, module_to_patch: torch.nn.Module, module_to_patch_key: str, patch: BaseLayerPatch, @@ -237,25 +196,16 @@ class LayerPatcher: original_modules: dict[str, torch.nn.Module], dtype: torch.dtype, ): - """Apply a single LoRA wrapper patch to a model.""" - # Replace the original module with a BaseSidecarWrapper if it has not already been done. - if not isinstance(module_to_patch, BaseSidecarWrapper): - wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch) - original_modules[module_to_patch_key] = module_to_patch - module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key) - module_parent = model.get_submodule(module_parent_key) - LayerPatcher._set_submodule(module_parent, module_name, wrapped_module) - else: - assert module_to_patch_key in original_modules - wrapped_module = module_to_patch - + """Apply a single LoRA wrapper patch to a module.""" # Move the LoRA layer to the same device/dtype as the orig module. first_param = next(module_to_patch.parameters()) device = first_param.device patch.to(device=device, dtype=dtype) - # Add the patch to the sidecar wrapper. - wrapped_module.add_patch(patch, patch_weight) + if module_to_patch_key not in original_modules: + original_modules[module_to_patch_key] = module_to_patch + + module_to_patch.add_patch(patch, patch_weight) @staticmethod def _split_parent_key(module_key: str) -> tuple[str, str]: diff --git a/invokeai/backend/patches/layers/base_layer_patch.py b/invokeai/backend/patches/layers/base_layer_patch.py index 5eb04864c8..f6f0289a90 100644 --- a/invokeai/backend/patches/layers/base_layer_patch.py +++ b/invokeai/backend/patches/layers/base_layer_patch.py @@ -5,7 +5,7 @@ import torch class BaseLayerPatch(ABC): @abstractmethod - def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]: + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: """Get the parameter residual updates that should be applied to the original parameters. Parameters omitted from the returned dict are not updated. """ diff --git a/invokeai/backend/patches/layers/concatenated_lora_layer.py b/invokeai/backend/patches/layers/concatenated_lora_layer.py index a098a9e61b..a699a47433 100644 --- a/invokeai/backend/patches/layers/concatenated_lora_layer.py +++ b/invokeai/backend/patches/layers/concatenated_lora_layer.py @@ -30,7 +30,7 @@ class ConcatenatedLoRALayer(LoRALayerBase): layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers] # pyright: ignore[reportArgumentType] return torch.cat(layer_weights, dim=self.concat_axis) - def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]: + def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]: # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. 
If we want to support sub-layers that # require this value, we will need to implement chunking of the original bias tensor here. # Note that we must apply the sub-layer scales here. diff --git a/invokeai/backend/patches/layers/flux_control_lora_layer.py b/invokeai/backend/patches/layers/flux_control_lora_layer.py index 142336a00a..ad592456a9 100644 --- a/invokeai/backend/patches/layers/flux_control_lora_layer.py +++ b/invokeai/backend/patches/layers/flux_control_lora_layer.py @@ -8,11 +8,11 @@ class FluxControlLoRALayer(LoRALayer): shapes don't match. """ - def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]: + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: """This overrides the base class behavior to skip the reshaping step.""" scale = self.scale() - params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)} - bias = self.get_bias(orig_module.bias) + params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)} + bias = self.get_bias(orig_parameters.get("bias", None)) if bias is not None: params["bias"] = bias * (weight * scale) diff --git a/invokeai/backend/patches/layers/lora_layer_base.py b/invokeai/backend/patches/layers/lora_layer_base.py index 13669ad5d3..123e5afa2c 100644 --- a/invokeai/backend/patches/layers/lora_layer_base.py +++ b/invokeai/backend/patches/layers/lora_layer_base.py @@ -54,19 +54,19 @@ class LoRALayerBase(BaseLayerPatch): def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: raise NotImplementedError() - def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]: + def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]: return self.bias - def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]: + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: scale = self.scale() - params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)} - bias = self.get_bias(orig_module.bias) + params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)} + bias = self.get_bias(orig_parameters.get("bias", None)) if bias is not None: params["bias"] = bias * (weight * scale) # Reshape all params to match the original module's shape. for param_name, param_weight in params.items(): - orig_param = orig_module.get_parameter(param_name) + orig_param = orig_parameters[param_name] if param_weight.shape != orig_param.shape: params[param_name] = param_weight.reshape(orig_param.shape) diff --git a/invokeai/backend/patches/layers/set_parameter_layer.py b/invokeai/backend/patches/layers/set_parameter_layer.py index f0ae461f4d..1b7fe94d36 100644 --- a/invokeai/backend/patches/layers/set_parameter_layer.py +++ b/invokeai/backend/patches/layers/set_parameter_layer.py @@ -14,10 +14,10 @@ class SetParameterLayer(BaseLayerPatch): self.weight = weight self.param_name = param_name - def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]: + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: # Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX # Control LoRA implementation. 
- diff = self.weight - orig_module.get_parameter(self.param_name) + diff = self.weight - orig_parameters[self.param_name] return {self.param_name: diff} def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): diff --git a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py deleted file mode 100644 index c22525bc95..0000000000 --- a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch - -from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch - - -class BaseSidecarWrapper(torch.nn.Module): - """A base class for sidecar wrappers. - - A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a - list of patches as 'sidecar' patches. I.e. it applies the sidecar patches during forward inference without modifying - the original module. - - Sidecar wrappers are typically used over regular patches when: - - The original module is quantized and so the weights can't be patched in the usual way. - - The original module is on the CPU and modifying the weights would require backing up the original weights and - doubling the CPU memory usage. - """ - - def __init__( - self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None - ): - super().__init__() - self._orig_module = orig_module - self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights - - @property - def orig_module(self) -> torch.nn.Module: - return self._orig_module - - def add_patch(self, patch: BaseLayerPatch, patch_weight: float): - """Add a patch to the sidecar wrapper.""" - self._patches_and_weights.append((patch, patch_weight)) - - def _aggregate_patch_parameters( - self, patches_and_weights: list[tuple[BaseLayerPatch, float]] - ) -> dict[str, torch.Tensor]: - """Helper function that aggregates the parameters from all patches into a single dict.""" - params: dict[str, torch.Tensor] = {} - - for patch, patch_weight in patches_and_weights: - # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original - # module, this might fail or return incorrect results. 
- layer_params = patch.get_parameters(self._orig_module, weight=patch_weight) - - for param_name, param_weight in layer_params.items(): - if param_name not in params: - params[param_name] = param_weight - else: - params[param_name] += param_weight - - return params - - def forward(self, *args, **kwargs): # type: ignore - raise NotImplementedError() diff --git a/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py deleted file mode 100644 index 7877aae8c7..0000000000 --- a/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper - - -class Conv1dSidecarWrapper(BaseSidecarWrapper): - def forward(self, input: torch.Tensor) -> torch.Tensor: - aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - return self.orig_module(input) + torch.nn.functional.conv1d( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) diff --git a/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py deleted file mode 100644 index d9bb713534..0000000000 --- a/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper - - -class Conv2dSidecarWrapper(BaseSidecarWrapper): - def forward(self, input: torch.Tensor) -> torch.Tensor: - aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - return self.orig_module(input) + torch.nn.functional.conv1d( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) diff --git a/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py deleted file mode 100644 index 34c3b9b369..0000000000 --- a/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch - -from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper - - -class FluxRMSNormSidecarWrapper(BaseSidecarWrapper): - """A sidecar wrapper for a FLUX RMSNorm layer. - - This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite - the RMSNorm scale parameters. - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - # Given the narrow focus of this wrapper, we only support a very particular patch configuration: - assert len(self._patches_and_weights) == 1 - patch, _patch_weight = self._patches_and_weights[0] - assert isinstance(patch, SetParameterLayer) - assert patch.param_name == "scale" - - # Apply the patch. - # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should - # be handled. 
- return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6) diff --git a/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py deleted file mode 100644 index 98775b9feb..0000000000 --- a/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch - -from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch -from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer -from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer -from invokeai.backend.patches.layers.lora_layer import LoRALayer -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper - - -class LinearSidecarWrapper(BaseSidecarWrapper): - def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor: - """An optimized implementation of the residual calculation for a Linear LoRALayer.""" - x = torch.nn.functional.linear(input, lora_layer.down) - if lora_layer.mid is not None: - x = torch.nn.functional.linear(x, lora_layer.mid) - x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias) - x *= lora_weight * lora_layer.scale() - return x - - def _concatenated_lora_forward( - self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float - ) -> torch.Tensor: - """An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer.""" - x_chunks: list[torch.Tensor] = [] - for lora_layer in concatenated_lora_layer.lora_layers: - x_chunk = torch.nn.functional.linear(input, lora_layer.down) - if lora_layer.mid is not None: - x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid) - x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias) - x_chunk *= lora_weight * lora_layer.scale() - x_chunks.append(x_chunk) - - # TODO(ryand): Generalize to support concat_axis != 0. - assert concatenated_lora_layer.concat_axis == 0 - x = torch.cat(x_chunks, dim=-1) - return x - - def forward(self, input: torch.Tensor) -> torch.Tensor: - # First, apply the original linear layer. - # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which - # change the linear layer's in_features. - orig_input = input - input = orig_input[..., : self.orig_module.in_features] - output = self.orig_module(input) - - # Then, apply layers for which we have optimized implementations. - unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = [] - for patch, patch_weight in self._patches_and_weights: - if isinstance(patch, FluxControlLoRALayer): - # Note that we use the original input here, not the sliced input. - output += self._lora_forward(orig_input, patch, patch_weight) - elif isinstance(patch, LoRALayer): - output += self._lora_forward(input, patch, patch_weight) - elif isinstance(patch, ConcatenatedLoRALayer): - output += self._concatenated_lora_forward(input, patch, patch_weight) - else: - unprocessed_patches_and_weights.append((patch, patch_weight)) - - # Finally, apply any remaining patches. 
- if len(unprocessed_patches_and_weights) > 0: - aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights) - output += torch.nn.functional.linear( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) - - return output diff --git a/invokeai/backend/patches/sidecar_wrappers/utils.py b/invokeai/backend/patches/sidecar_wrappers/utils.py deleted file mode 100644 index 6a71213b09..0000000000 --- a/invokeai/backend/patches/sidecar_wrappers/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch - -from invokeai.backend.flux.modules.layers import RMSNorm -from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper -from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper -from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper -from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper - - -def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module: - if isinstance(orig_module, torch.nn.Linear): - return LinearSidecarWrapper(orig_module) - elif isinstance(orig_module, torch.nn.Conv1d): - return Conv1dSidecarWrapper(orig_module) - elif isinstance(orig_module, torch.nn.Conv2d): - return Conv2dSidecarWrapper(orig_module) - elif isinstance(orig_module, RMSNorm): - return FluxRMSNormSidecarWrapper(orig_module) - else: - raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}") diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index a9f5d68b76..62be2bdb63 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -48,11 +48,13 @@ GGML_TENSOR_OP_TABLE = { # Ops to run on the quantized tensor. torch.ops.aten.detach.default: apply_to_quantized_tensor, # pyright: ignore torch.ops.aten._to_copy.default: apply_to_quantized_tensor, # pyright: ignore + torch.ops.aten.clone.default: apply_to_quantized_tensor, # pyright: ignore # Ops to run on dequantized tensors. 
torch.ops.aten.t.default: dequantize_and_run, # pyright: ignore torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore + torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore } if torch.backends.mps.is_available(): diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py index 9e04f8e941..43986fad4d 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora.py +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -5,8 +5,8 @@ from typing import TYPE_CHECKING from diffusers import UNet2DConditionModel +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase if TYPE_CHECKING: @@ -31,12 +31,16 @@ class LoRAExt(ExtensionBase): def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage): lora_model = self._node_context.models.load(self._model_id).model assert isinstance(lora_model, ModelPatchRaw) - LayerPatcher.apply_model_patch( + LayerPatcher.apply_smart_model_patch( model=unet, prefix="lora_unet_", patch=lora_model, patch_weight=self._weight, original_weights=original_weights, + original_modules={}, + dtype=unet.dtype, + force_direct_patching=True, + force_sidecar_patching=False, ) del lora_model diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py index e3c99d0c34..4fae046cf8 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py @@ -1,18 +1,27 @@ import itertools +import pytest import torch from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) from invokeai.backend.util.calc_tensor_size import calc_tensor_size from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda -@parameterize_mps_and_cuda -def test_cached_model_total_bytes(device: str): +@pytest.fixture +def model(): model = DummyModule() + apply_custom_layers_to_model(model) + return model + + +@parameterize_mps_and_cuda +def test_cached_model_total_bytes(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) linear1_numel = 10 * 32 + 32 linear2_numel = 32 * 64 + 64 @@ -22,8 +31,7 @@ def test_cached_model_total_bytes(device: str): @parameterize_mps_and_cuda -def test_cached_model_cur_vram_bytes(device: str): - model = DummyModule() +def test_cached_model_cur_vram_bytes(device: str, model: DummyModule): # Model starts in CPU memory. 
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) assert cached_model.cur_vram_bytes() == 0 @@ -37,8 +45,7 @@ def test_cached_model_cur_vram_bytes(device: str): @parameterize_mps_and_cuda -def test_cached_model_partial_load(device: str): - model = DummyModule() +def test_cached_model_partial_load(device: str, model: DummyModule): # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) model_total_bytes = cached_model.total_bytes() @@ -58,14 +65,13 @@ def test_cached_model_partial_load(device: str): if p.device.type == device and n != "buffer2" ) - # Check that the model's modules have been patched with CustomLinear layers. - assert type(model.linear1) is CustomLinear - assert type(model.linear2) is CustomLinear + # Check that the model's modules have device autocasting enabled. + assert model.linear1.is_device_autocasting_enabled() + assert model.linear2.is_device_autocasting_enabled() @parameterize_mps_and_cuda -def test_cached_model_partial_unload(device: str): - model = DummyModule() +def test_cached_model_partial_unload(device: str, model: DummyModule): # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) model_total_bytes = cached_model.total_bytes() @@ -87,14 +93,13 @@ def test_cached_model_partial_unload(device: str): calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu" ) - # Check that the model's modules are still patched with CustomLinear layers. - assert type(model.linear1) is CustomLinear - assert type(model.linear2) is CustomLinear + # Check that the model's modules still have device autocasting enabled. + assert model.linear1.is_device_autocasting_enabled() + assert model.linear2.is_device_autocasting_enabled() @parameterize_mps_and_cuda -def test_cached_model_full_load_and_unload(device: str): - model = DummyModule() +def test_cached_model_full_load_and_unload(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. @@ -107,8 +112,8 @@ def test_cached_model_full_load_and_unload(device: str): assert loaded_bytes == model_total_bytes assert loaded_bytes == cached_model.cur_vram_bytes() assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers())) - assert type(model.linear1) is torch.nn.Linear - assert type(model.linear2) is torch.nn.Linear + assert not model.linear1.is_device_autocasting_enabled() + assert not model.linear2.is_device_autocasting_enabled() # Full unload the model from VRAM. unloaded_bytes = cached_model.full_unload_from_vram() @@ -126,8 +131,7 @@ def test_cached_model_full_load_and_unload(device: str): @parameterize_mps_and_cuda -def test_cached_model_full_load_from_partial(device: str): - model = DummyModule() +def test_cached_model_full_load_from_partial(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. 
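Note (sketch, not part of the patch): the assertion changes in these tests replace class-identity checks with a flag check because layers now stay wrapped at all times and only the device-autocasting flag changes with load state. "Device autocasting enabled" means a wrapped layer streams its CPU-resident weights to the input's device at forward time instead of failing with a device mismatch. A minimal illustration, assuming a CUDA device and the wrap_single_custom_layer helper defined in the new test_all_custom_modules.py later in this diff:

layer = torch.nn.Linear(8, 16)               # weights start on the CPU
custom = wrap_single_custom_layer(layer)     # test helper from test_all_custom_modules.py
custom.set_device_autocasting_enabled(True)  # opt in to streaming weights at forward time

x = torch.randn(1, 8).to("cuda")             # input lives on the GPU
y = custom(x)                                # weights are streamed to CUDA for the forward pass
assert y.device.type == "cuda"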
@@ -140,8 +144,8 @@ def test_cached_model_full_load_from_partial(device: str): assert loaded_bytes > 0 assert loaded_bytes < model_total_bytes assert loaded_bytes == cached_model.cur_vram_bytes() - assert type(model.linear1) is CustomLinear - assert type(model.linear2) is CustomLinear + assert model.linear1.is_device_autocasting_enabled() + assert model.linear2.is_device_autocasting_enabled() # Full load the rest of the model into VRAM. loaded_bytes_2 = cached_model.full_load_to_vram() @@ -150,13 +154,12 @@ def test_cached_model_full_load_from_partial(device: str): assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes() assert loaded_bytes + loaded_bytes_2 == model_total_bytes assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers())) - assert type(model.linear1) is torch.nn.Linear - assert type(model.linear2) is torch.nn.Linear + assert not model.linear1.is_device_autocasting_enabled() + assert not model.linear2.is_device_autocasting_enabled() @parameterize_mps_and_cuda -def test_cached_model_full_unload_from_partial(device: str): - model = DummyModule() +def test_cached_model_full_unload_from_partial(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. @@ -184,8 +187,7 @@ def test_cached_model_full_unload_from_partial(device: str): @parameterize_mps_and_cuda -def test_cached_model_get_cpu_state_dict(device: str): - model = DummyModule() +def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. @@ -209,8 +211,7 @@ def test_cached_model_get_cpu_state_dict(device: str): @parameterize_mps_and_cuda -def test_cached_model_full_load_and_inference(device: str): - model = DummyModule() +def test_cached_model_full_load_and_inference(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. model_total_bytes = cached_model.total_bytes() @@ -237,8 +238,7 @@ def test_cached_model_full_load_and_inference(device: str): @parameterize_mps_and_cuda -def test_cached_model_partial_load_and_inference(device: str): - model = DummyModule() +def test_cached_model_partial_load_and_inference(device: str, model: DummyModule): # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) model_total_bytes = cached_model.total_bytes() @@ -262,9 +262,9 @@ def test_cached_model_partial_load_and_inference(device: str): for n, p in itertools.chain(model.named_parameters(), model.named_buffers()) if p.device.type == device and n != "buffer2" ) - # Check that the model's modules have been patched with CustomLinear layers. - assert type(model.linear1) is CustomLinear - assert type(model.linear2) is CustomLinear + # Check that the model's modules have device autocasting enabled. + assert model.linear1.is_device_autocasting_enabled() + assert model.linear2.is_device_autocasting_enabled() # Run inference on the GPU. 
output2 = model(x.to(device)) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py new file mode 100644 index 0000000000..9706277234 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -0,0 +1,530 @@ +import copy + +import gguf +import pytest +import torch + +from invokeai.backend.flux.modules.layers import RMSNorm +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + AUTOCAST_MODULE_TYPE_MAPPING, + AUTOCAST_MODULE_TYPE_MAPPING_INVERSE, + unwrap_custom_layer, + wrap_custom_layer, +) +from invokeai.backend.patches.layer_patcher import LayerPatcher +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer +from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer +from invokeai.backend.patches.layers.lora_layer import LoRALayer +from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage +from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import ( + build_linear_8bit_lt_layer, +) +from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_nf4 import ( + build_linear_nf4_layer, +) +from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor + + +def build_linear_layer_with_ggml_quantized_tensor(orig_layer: torch.nn.Linear | None = None): + if orig_layer is None: + orig_layer = torch.nn.Linear(32, 64) + + ggml_quantized_weight = quantize_tensor(orig_layer.weight, gguf.GGMLQuantizationType.Q8_0) + orig_layer.weight = torch.nn.Parameter(ggml_quantized_weight) + ggml_quantized_bias = quantize_tensor(orig_layer.bias, gguf.GGMLQuantizationType.Q8_0) + orig_layer.bias = torch.nn.Parameter(ggml_quantized_bias) + return orig_layer + + +parameterize_all_devices = pytest.mark.parametrize( + ("device"), + [ + pytest.param("cpu"), + pytest.param( + "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.") + ), + pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")), + ], +) + +parameterize_cuda_and_mps = pytest.mark.parametrize( + ("device"), + [ + pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")), + pytest.param( + "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.") + ), + ], +) + + +LayerUnderTest = tuple[torch.nn.Module, torch.Tensor, bool] + + +@pytest.fixture( + params=[ + "linear", + "conv1d", + "conv2d", + "group_norm", + "embedding", + "flux_rms_norm", + "linear_with_ggml_quantized_tensor", + "invoke_linear_8_bit_lt", + "invoke_linear_nf4", + ] +) +def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest: + """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test.""" + layer_type = request.param + if layer_type == "linear": + return (torch.nn.Linear(8, 16), torch.randn(1, 8), True) + elif layer_type == "conv1d": + return (torch.nn.Conv1d(8, 16, 3), torch.randn(1, 8, 5), True) + elif layer_type == "conv2d": + return 
(torch.nn.Conv2d(8, 16, 3), torch.randn(1, 8, 5, 5), True) + elif layer_type == "group_norm": + return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True) + elif layer_type == "embedding": + return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True) + elif layer_type == "flux_rms_norm": + return (RMSNorm(8), torch.randn(1, 8), True) + elif layer_type == "linear_with_ggml_quantized_tensor": + return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True) + elif layer_type == "invoke_linear_8_bit_lt": + return (build_linear_8bit_lt_layer(), torch.randn(1, 32), False) + elif layer_type == "invoke_linear_nf4": + return (build_linear_nf4_layer(), torch.randn(1, 64), False) + else: + raise ValueError(f"Unsupported layer_type: {layer_type}") + + +def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str): + """A helper function to move a layer to a device by roundtripping through a state dict. This most closely matches + how models are moved in the app. Some of the quantization types have broken semantics around calling .to() on the + layer directly, so this is a workaround. + + We should fix this in the future. + Relevant article: https://pytorch.org/tutorials/recipes/recipes/swap_tensors.html + """ + state_dict = layer.state_dict() + state_dict = {k: v.to(device) for k, v in state_dict.items()} + layer.load_state_dict(state_dict, assign=True) + + +def wrap_single_custom_layer(layer: torch.nn.Module): + custom_layer_type = AUTOCAST_MODULE_TYPE_MAPPING[type(layer)] + return wrap_custom_layer(layer, custom_layer_type) + + +def unwrap_single_custom_layer(layer: torch.nn.Module): + orig_layer_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE[type(layer)] + return unwrap_custom_layer(layer, orig_layer_type) + + +def test_isinstance(layer_under_test: LayerUnderTest): + """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer.""" + orig_layer, _, _ = layer_under_test + orig_type = type(orig_layer) + + custom_layer = wrap_single_custom_layer(orig_layer) + + assert isinstance(custom_layer, orig_type) + assert type(custom_layer) is not orig_type + + +def test_wrap_and_unwrap(layer_under_test: LayerUnderTest): + """Test that wrapping and unwrapping a layer behaves as expected.""" + orig_layer, _, _ = layer_under_test + orig_type = type(orig_layer) + + # Wrap the original layer and assert that attributes of the custom layer can be accessed. + custom_layer = wrap_single_custom_layer(orig_layer) + custom_layer.set_device_autocasting_enabled(True) + assert custom_layer._device_autocasting_enabled + + # Unwrap the custom layer. + # Assert that the methods of the wrapped layer are no longer accessible. + unwrapped_layer = unwrap_single_custom_layer(custom_layer) + with pytest.raises(AttributeError): + _ = unwrapped_layer.set_device_autocasting_enabled(True) + # For now, we have chosen to allow attributes to persist. We may revisit this in the future. + assert unwrapped_layer._device_autocasting_enabled + assert type(unwrapped_layer) is orig_type + + +@parameterize_all_devices +def test_state_dict(device: str, layer_under_test: LayerUnderTest): + """Test that .state_dict() behaves the same on the original layer and the wrapped layer.""" + orig_layer, _, _ = layer_under_test + + # Get the original layer on the test device. + orig_layer.to(device) + orig_state_dict = orig_layer.state_dict() + + # Wrap the original layer. 
+ custom_layer = copy.deepcopy(orig_layer) + custom_layer = wrap_single_custom_layer(custom_layer) + + custom_state_dict = custom_layer.state_dict() + + assert set(orig_state_dict.keys()) == set(custom_state_dict.keys()) + for k in orig_state_dict: + assert orig_state_dict[k].shape == custom_state_dict[k].shape + assert orig_state_dict[k].dtype == custom_state_dict[k].dtype + assert orig_state_dict[k].device == custom_state_dict[k].device + assert torch.allclose(orig_state_dict[k], custom_state_dict[k]) + + +@parameterize_all_devices +def test_load_state_dict(device: str, layer_under_test: LayerUnderTest): + """Test that .load_state_dict() behaves the same on the original layer and the wrapped layer.""" + orig_layer, _, _ = layer_under_test + + orig_layer.to(device) + + custom_layer = copy.deepcopy(orig_layer) + custom_layer = wrap_single_custom_layer(custom_layer) + + # Do a state dict roundtrip. + orig_state_dict = orig_layer.state_dict() + custom_state_dict = custom_layer.state_dict() + + orig_layer.load_state_dict(custom_state_dict, assign=True) + custom_layer.load_state_dict(orig_state_dict, assign=True) + + orig_state_dict = orig_layer.state_dict() + custom_state_dict = custom_layer.state_dict() + + # Assert that the state dicts are the same after the roundtrip. + assert set(orig_state_dict.keys()) == set(custom_state_dict.keys()) + for k in orig_state_dict: + assert orig_state_dict[k].shape == custom_state_dict[k].shape + assert orig_state_dict[k].dtype == custom_state_dict[k].dtype + assert orig_state_dict[k].device == custom_state_dict[k].device + assert torch.allclose(orig_state_dict[k], custom_state_dict[k]) + + +@parameterize_all_devices +def test_inference_on_device(device: str, layer_under_test: LayerUnderTest): + """Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the + device. + """ + orig_layer, layer_input, supports_cpu_inference = layer_under_test + + if device == "cpu" and not supports_cpu_inference: + pytest.skip("Layer does not support CPU inference.") + + layer_to_device_via_state_dict(orig_layer, device) + + custom_layer = copy.deepcopy(orig_layer) + custom_layer = wrap_single_custom_layer(custom_layer) + + # Run inference with the original layer. + x = layer_input.to(device) + orig_output = orig_layer(x) + + # Run inference with the wrapped layer. + custom_output = custom_layer(x) + + assert torch.allclose(orig_output, custom_output) + + +@parameterize_cuda_and_mps +def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: LayerUnderTest): + """Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the + device. + """ + orig_layer, layer_input, supports_cpu_inference = layer_under_test + + if device == "cpu" and not supports_cpu_inference: + pytest.skip("Layer does not support CPU inference.") + + # Make sure the original layer is on the device. + layer_to_device_via_state_dict(orig_layer, device) + + x = layer_input.to(device) + + # Run inference with the original layer on the device. + orig_output = orig_layer(x) + + # Move the original layer to the CPU. + layer_to_device_via_state_dict(orig_layer, "cpu") + + # Inference should fail with an input on the device. + with pytest.raises(RuntimeError): + _ = orig_layer(x) + + # Wrap the original layer. + custom_layer = copy.deepcopy(orig_layer) + custom_layer = wrap_single_custom_layer(custom_layer) + + # Inference should still fail with autocasting disabled. 
+ custom_layer.set_device_autocasting_enabled(False) + with pytest.raises(RuntimeError): + _ = custom_layer(x) + + # Run inference with the wrapped layer on the device. + custom_layer.set_device_autocasting_enabled(True) + custom_output = custom_layer(x) + assert custom_output.device.type == device + + assert torch.allclose(orig_output, custom_output) + + +PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor] + + +@pytest.fixture( + params=[ + "single_lora", + "multiple_loras", + "concatenated_lora", + "flux_control_lora", + ] +) +def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest: + """A fixture that returns a tuple of (patches, input) for the patch under test.""" + layer_type = request.param + torch.manual_seed(0) + + # The assumed in/out features of the base linear layer. + in_features = 32 + out_features = 64 + + rank = 4 + + if layer_type == "single_lora": + lora_layer = LoRALayer( + up=torch.randn(out_features, rank), + mid=None, + down=torch.randn(rank, in_features), + alpha=1.0, + bias=torch.randn(out_features), + ) + input = torch.randn(1, in_features) + return ([(lora_layer, 0.7)], input) + elif layer_type == "multiple_loras": + lora_layer = LoRALayer( + up=torch.randn(out_features, rank), + mid=None, + down=torch.randn(rank, in_features), + alpha=1.0, + bias=torch.randn(out_features), + ) + lora_layer_2 = LoRALayer( + up=torch.randn(out_features, rank), + mid=None, + down=torch.randn(rank, in_features), + alpha=1.0, + bias=torch.randn(out_features), + ) + + input = torch.randn(1, in_features) + return ([(lora_layer, 1.0), (lora_layer_2, 0.5)], input) + elif layer_type == "concatenated_lora": + sub_layer_out_features = [16, 16, 32] + + # Create a ConcatenatedLoRA layer. + sub_layers: list[LoRALayer] = [] + for out_features in sub_layer_out_features: + down = torch.randn(rank, in_features) + up = torch.randn(out_features, rank) + bias = torch.randn(out_features) + sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)) + concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0) + + input = torch.randn(1, in_features) + return ([(concatenated_lora_layer, 0.7)], input) + elif layer_type == "flux_control_lora": + # Create a FluxControlLoRALayer. + patched_in_features = 40 + lora_layer = FluxControlLoRALayer( + up=torch.randn(out_features, rank), + mid=None, + down=torch.randn(rank, patched_in_features), + alpha=1.0, + bias=torch.randn(out_features), + ) + + input = torch.randn(1, patched_in_features) + return ([(lora_layer, 0.7)], input) + else: + raise ValueError(f"Unsupported layer_type: {layer_type}") + + +@parameterize_all_devices +def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest): + patches, input = patch_under_test + + # Build the base layer under test. + layer = torch.nn.Linear(32, 64) + + # Move the layer and input to the device. + layer_to_device_via_state_dict(layer, device) + input = input.to(torch.device(device)) + + # Patch the LoRA layer into the linear layer. + layer_patched = copy.deepcopy(layer) + for patch, weight in patches: + LayerPatcher._apply_model_layer_patch( + module_to_patch=layer_patched, + module_to_patch_key="", + patch=patch, + patch_weight=weight, + original_weights=OriginalWeightsStorage(), + ) + + # Wrap the original layer in a custom layer and add the patch to it as a sidecar. 
+ custom_layer = wrap_single_custom_layer(layer) + for patch, weight in patches: + patch.to(torch.device(device)) + custom_layer.add_patch(patch, weight) + + # Run inference with the original layer and the patched layer and assert they are equal. + output_patched = layer_patched(input) + output_custom = custom_layer(input) + assert torch.allclose(output_patched, output_custom, atol=1e-6) + + +@parameterize_cuda_and_mps +def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest): + """Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and + when the layer is on the CPU and the patches are autocasted to the device. + """ + patches, input = patch_under_test + + # Build the base layer under test. + layer = torch.nn.Linear(32, 64) + + # Move the layer and input to the device. + layer_to_device_via_state_dict(layer, device) + input = input.to(torch.device(device)) + + # Wrap the original layer in a custom layer and add the patch to it. + custom_layer = wrap_single_custom_layer(layer) + for patch, weight in patches: + patch.to(torch.device(device)) + custom_layer.add_patch(patch, weight) + + # Run inference with the custom layer on the device. + expected_output = custom_layer(input) + + # Move the custom layer to the CPU. + layer_to_device_via_state_dict(custom_layer, "cpu") + + # Move the patches to the CPU. + custom_layer.clear_patches() + for patch, weight in patches: + patch.to(torch.device("cpu")) + custom_layer.add_patch(patch, weight) + + # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to + # the device. + autocast_output = custom_layer(input) + assert autocast_output.device.type == device + + # Assert that the outputs with and without autocasting are the same. + assert torch.allclose(expected_output, autocast_output, atol=1e-6) + + +@pytest.fixture( + params=[ + "linear_ggml_quantized", + "invoke_linear_8_bit_lt", + "invoke_linear_nf4", + ] +) +def quantized_linear_layer_under_test(request: pytest.FixtureRequest): + in_features = 32 + out_features = 64 + torch.manual_seed(0) + layer_type = request.param + orig_layer = torch.nn.Linear(in_features, out_features) + if layer_type == "linear_ggml_quantized": + return orig_layer, build_linear_layer_with_ggml_quantized_tensor(orig_layer) + elif layer_type == "invoke_linear_8_bit_lt": + return orig_layer, build_linear_8bit_lt_layer(orig_layer) + elif layer_type == "invoke_linear_nf4": + return orig_layer, build_linear_nf4_layer(orig_layer) + else: + raise ValueError(f"Unsupported layer_type: {layer_type}") + + +@parameterize_cuda_and_mps +def test_quantized_linear_sidecar_patches( + device: str, + quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module], + patch_under_test: PatchUnderTest, +): + """Test that patches can be applied to quantized linear layers and that the output is the same as when the patch is + applied to a non-quantized linear layer. + """ + patches, input = patch_under_test + + linear_layer, quantized_linear_layer = quantized_linear_layer_under_test + + # Move everything to the device. + layer_to_device_via_state_dict(linear_layer, device) + layer_to_device_via_state_dict(quantized_linear_layer, device) + input = input.to(torch.device(device)) + + # Wrap both layers in custom layers. 
+ linear_layer_custom = wrap_single_custom_layer(linear_layer) + quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer) + + # Apply the patches to the custom layers. + for patch, weight in patches: + patch.to(torch.device(device)) + linear_layer_custom.add_patch(patch, weight) + quantized_linear_layer_custom.add_patch(patch, weight) + + # Run inference with the original layer and the patched layer and assert they are equal. + output_linear_patched = linear_layer_custom(input) + output_quantized_patched = quantized_linear_layer_custom(input) + assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2) + + +@parameterize_cuda_and_mps +def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device( + device: str, + quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module], + patch_under_test: PatchUnderTest, +): + """Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and + when the layer is on the CPU and the patches are autocasted to the device. + """ + patches, input = patch_under_test + + _, quantized_linear_layer = quantized_linear_layer_under_test + + # Move everything to the device. + layer_to_device_via_state_dict(quantized_linear_layer, device) + input = input.to(torch.device(device)) + + # Wrap the quantized linear layer in a custom layer and add the patch to it. + quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer) + for patch, weight in patches: + patch.to(torch.device(device)) + quantized_linear_layer_custom.add_patch(patch, weight) + + # Run inference with the custom layer on the device. + expected_output = quantized_linear_layer_custom(input) + + # Move the custom layer to the CPU. + layer_to_device_via_state_dict(quantized_linear_layer_custom, "cpu") + + # Move the patches to the CPU. + quantized_linear_layer_custom.clear_patches() + for patch, weight in patches: + patch.to(torch.device("cpu")) + quantized_linear_layer_custom.add_patch(patch, weight) + + # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to + # the device. + autocast_output = quantized_linear_layer_custom(input) + assert autocast_output.device.type == device + + # Assert that the outputs with and without autocasting are the same. + assert torch.allclose(expected_output, autocast_output, atol=1e-6) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py new file mode 100644 index 0000000000..05e15302d5 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py @@ -0,0 +1,31 @@ +import torch + +from invokeai.backend.flux.modules.layers import RMSNorm +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import ( + CustomFluxRMSNorm, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + wrap_custom_layer, +) +from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer + + +def test_custom_flux_rms_norm_patch(): + """Test a SetParameterLayer patch on a CustomFluxRMSNorm layer.""" + # Create a RMSNorm layer. + dim = 8 + rms_norm = RMSNorm(dim) + + # Create a SetParameterLayer. 
+ new_scale = torch.randn(dim) + set_parameter_layer = SetParameterLayer("scale", new_scale) + + # Wrap the RMSNorm layer in a CustomFluxRMSNorm layer. + custom_flux_rms_norm = wrap_custom_layer(rms_norm, CustomFluxRMSNorm) + custom_flux_rms_norm.add_patch(set_parameter_layer, 1.0) + + # Run the CustomFluxRMSNorm layer. + input = torch.randn(1, dim) + expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6) + output_custom = custom_flux_rms_norm(input) + assert torch.allclose(output_custom, expected_output, atol=1e-6) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py new file mode 100644 index 0000000000..9a225267fb --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py @@ -0,0 +1,82 @@ +import pytest +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + wrap_custom_layer, +) + +if not torch.cuda.is_available(): + pytest.skip("CUDA is not available", allow_module_level=True) +else: + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import ( + CustomInvokeLinear8bitLt, + ) + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt + + +def build_linear_8bit_lt_layer(orig_layer: torch.nn.Linear | None = None): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + torch.manual_seed(1) + + if orig_layer is None: + orig_layer = torch.nn.Linear(32, 64) + orig_layer_state_dict = orig_layer.state_dict() + + # Prepare a quantized InvokeLinear8bitLt layer. + quantized_layer = InvokeLinear8bitLt( + input_features=orig_layer.in_features, output_features=orig_layer.out_features, has_fp16_weights=False + ) + quantized_layer.load_state_dict(orig_layer_state_dict) + quantized_layer.to("cuda") + + # Assert that the InvokeLinear8bitLt layer is quantized. + assert quantized_layer.weight.CB is not None + assert quantized_layer.weight.SCB is not None + assert quantized_layer.weight.CB.dtype == torch.int8 + + return quantized_layer + + +@pytest.fixture +def linear_8bit_lt_layer(): + return build_linear_8bit_lt_layer() + + +def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt): + """Test CustomInvokeLinear8bitLt inference with all weights on the GPU.""" + # Run inference on the original layer. + x = torch.randn(1, 32).to("cuda") + y_quantized = linear_8bit_lt_layer(x) + + # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. + custom_linear_8bit_lt_layer = wrap_custom_layer(linear_8bit_lt_layer, CustomInvokeLinear8bitLt) + y_custom = custom_linear_8bit_lt_layer(x) + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) + + +def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt): + """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU).""" + # Run inference on the original layer. + x = torch.randn(1, 32).to("cuda") + y_quantized = linear_8bit_lt_layer(x) + + # Copy the state dict to the CPU and reload it. 
+ state_dict = linear_8bit_lt_layer.state_dict() + state_dict = {k: v.to("cpu") for k, v in state_dict.items()} + linear_8bit_lt_layer.load_state_dict(state_dict) + + # Inference of the original layer should fail. + with pytest.raises(RuntimeError): + linear_8bit_lt_layer(x) + + # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. + custom_linear_8bit_lt_layer = wrap_custom_layer(linear_8bit_lt_layer, CustomInvokeLinear8bitLt) + custom_linear_8bit_lt_layer.set_device_autocasting_enabled(True) + y_custom = custom_linear_8bit_lt_layer(x) + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py new file mode 100644 index 0000000000..f97404fb94 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py @@ -0,0 +1,92 @@ +import pytest +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + wrap_custom_layer, +) + +if not torch.cuda.is_available(): + pytest.skip("CUDA is not available", allow_module_level=True) +else: + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( + CustomInvokeLinearNF4, + ) + from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 + + +def build_linear_nf4_layer(orig_layer: torch.nn.Linear | None = None): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + torch.manual_seed(1) + + if orig_layer is None: + orig_layer = torch.nn.Linear(64, 16) + + orig_layer_state_dict = orig_layer.state_dict() + + # Prepare a quantized InvokeLinearNF4 layer. + quantized_layer = InvokeLinearNF4(input_features=orig_layer.in_features, output_features=orig_layer.out_features) + quantized_layer.load_state_dict(orig_layer_state_dict) + quantized_layer.to("cuda") + + # Assert that the InvokeLinearNF4 layer is quantized. + assert quantized_layer.weight.bnb_quantized + + return quantized_layer + + +@pytest.fixture +def linear_nf4_layer(): + return build_linear_nf4_layer() + + +def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4): + """Test CustomInvokeLinearNF4 inference with all weights on the GPU.""" + # Run inference on the original layer. + x = torch.randn(1, 64).to("cuda") + y_quantized = linear_nf4_layer(x) + + # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. + custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4) + custom_linear_nf4_layer.set_device_autocasting_enabled(True) + y_custom = custom_linear_nf4_layer(x) + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) + + +# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the +# input dimension, and this has caused issues in the past. 
+@pytest.mark.parametrize("input_dim_0", [1, 2]) +def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int): + """Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU).""" + # Run inference on the original layer. + x = torch.randn(input_dim_0, 64).to(device="cuda") + y_quantized = linear_nf4_layer(x) + + # Copy the state dict to the CPU and reload it. + state_dict = linear_nf4_layer.state_dict() + state_dict = {k: v.to("cpu") for k, v in state_dict.items()} + linear_nf4_layer.load_state_dict(state_dict) + + # Inference of the original layer should fail. + with pytest.raises(RuntimeError): + linear_nf4_layer(x) + + # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. + custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4) + custom_linear_nf4_layer.set_device_autocasting_enabled(True) + y_custom = custom_linear_nf4_layer(x) + + # Assert that the state dict (and the tensors that it references) are still on the CPU. + assert all(v.device == torch.device("cpu") for v in state_dict.values()) + + # Assert that the weight, bias, and quant_state are all on the CPU. + assert custom_linear_nf4_layer.weight.device == torch.device("cpu") + assert custom_linear_nf4_layer.bias.device == torch.device("cpu") + assert custom_linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu") + assert custom_linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu") + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py deleted file mode 100644 index 38fa467c60..0000000000 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py +++ /dev/null @@ -1,144 +0,0 @@ -import pytest -import torch - -if not torch.cuda.is_available(): - pytest.skip("CUDA is not available", allow_module_level=True) -else: - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import ( - CustomInvokeLinear8bitLt, - ) - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import ( - CustomInvokeLinearNF4, - ) - from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt - from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 - - -@pytest.fixture -def linear_8bit_lt_layer(): - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - torch.manual_seed(1) - - orig_layer = torch.nn.Linear(32, 64) - orig_layer_state_dict = orig_layer.state_dict() - - # Prepare a quantized InvokeLinear8bitLt layer. - quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False) - quantized_layer.load_state_dict(orig_layer_state_dict) - quantized_layer.to("cuda") - - # Assert that the InvokeLinear8bitLt layer is quantized. 
- assert quantized_layer.weight.CB is not None - assert quantized_layer.weight.SCB is not None - assert quantized_layer.weight.CB.dtype == torch.int8 - - return quantized_layer - - -def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt): - """Test CustomInvokeLinear8bitLt inference with all weights on the GPU.""" - # Run inference on the original layer. - x = torch.randn(1, 32).to("cuda") - y_quantized = linear_8bit_lt_layer(x) - - # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. - linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt - y_custom = linear_8bit_lt_layer(x) - - # Assert that the quantized and custom layers produce the same output. - assert torch.allclose(y_quantized, y_custom, atol=1e-5) - - -def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt): - """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU).""" - # Run inference on the original layer. - x = torch.randn(1, 32).to("cuda") - y_quantized = linear_8bit_lt_layer(x) - - # Copy the state dict to the CPU and reload it. - state_dict = linear_8bit_lt_layer.state_dict() - state_dict = {k: v.to("cpu") for k, v in state_dict.items()} - linear_8bit_lt_layer.load_state_dict(state_dict) - - # Inference of the original layer should fail. - with pytest.raises(RuntimeError): - linear_8bit_lt_layer(x) - - # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. - linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt - y_custom = linear_8bit_lt_layer(x) - - # Assert that the quantized and custom layers produce the same output. - assert torch.allclose(y_quantized, y_custom, atol=1e-5) - - -@pytest.fixture -def linear_nf4_layer(): - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - torch.manual_seed(1) - - orig_layer = torch.nn.Linear(64, 16) - orig_layer_state_dict = orig_layer.state_dict() - - # Prepare a quantized InvokeLinearNF4 layer. - quantized_layer = InvokeLinearNF4(input_features=64, output_features=16) - quantized_layer.load_state_dict(orig_layer_state_dict) - quantized_layer.to("cuda") - - # Assert that the InvokeLinearNF4 layer is quantized. - assert quantized_layer.weight.bnb_quantized - - return quantized_layer - - -def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4): - """Test CustomInvokeLinearNF4 inference with all weights on the GPU.""" - # Run inference on the original layer. - x = torch.randn(1, 64).to("cuda") - y_quantized = linear_nf4_layer(x) - - # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. - linear_nf4_layer.__class__ = CustomInvokeLinearNF4 - y_custom = linear_nf4_layer(x) - - # Assert that the quantized and custom layers produce the same output. - assert torch.allclose(y_quantized, y_custom, atol=1e-5) - - -# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the -# input dimension, and this has caused issues in the past. -@pytest.mark.parametrize("input_dim_0", [1, 2]) -def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int): - """Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU).""" - # Run inference on the original layer. 
- x = torch.randn(input_dim_0, 64).to(device="cuda") - y_quantized = linear_nf4_layer(x) - - # Copy the state dict to the CPU and reload it. - state_dict = linear_nf4_layer.state_dict() - state_dict = {k: v.to("cpu") for k, v in state_dict.items()} - linear_nf4_layer.load_state_dict(state_dict) - - # Inference of the original layer should fail. - with pytest.raises(RuntimeError): - linear_nf4_layer(x) - - # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. - linear_nf4_layer.__class__ = CustomInvokeLinearNF4 - y_custom = linear_nf4_layer(x) - - # Assert that the state dict (and the tensors that it references) are still on the CPU. - assert all(v.device == torch.device("cpu") for v in state_dict.values()) - - # Assert that the weight, bias, and quant_state are all on the CPU. - assert linear_nf4_layer.weight.device == torch.device("cpu") - assert linear_nf4_layer.bias.device == torch.device("cpu") - assert linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu") - assert linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu") - - # Assert that the quantized and custom layers produce the same output. - assert torch.allclose(y_quantized, y_custom, atol=1e-5) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py index 65b9f66066..1861597a63 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py @@ -72,7 +72,7 @@ def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.n assert expected.device.type == "cpu" # Apply the custom layers to the model. - apply_custom_layers_to_model(model) + apply_custom_layers_to_model(model, device_autocasting_enabled=True) # Run the model on the device. autocast_result = model(x.to(device)) @@ -122,7 +122,7 @@ def test_torch_module_autocast_bnb_llm_int8_linear_layer(): # Move the model back to the CPU and add the custom layers to the model. model.to("cpu") - apply_custom_layers_to_model(model) + apply_custom_layers_to_model(model, device_autocasting_enabled=True) # Run inference with weights being streamed to the GPU. autocast_result = model(x.to("cuda")) diff --git a/tests/backend/patches/layers/test_flux_control_lora_layer.py b/tests/backend/patches/layers/test_flux_control_lora_layer.py index 00590c3514..129fcfcb4e 100644 --- a/tests/backend/patches/layers/test_flux_control_lora_layer.py +++ b/tests/backend/patches/layers/test_flux_control_lora_layer.py @@ -18,7 +18,7 @@ def test_flux_control_lora_layer_get_parameters(): orig_module = torch.nn.Linear(small_in_features, out_features) # Test that get_parameters() behaves as expected in spite of the difference in in_features shapes. 
- params = layer.get_parameters(orig_module, weight=1.0) + params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0) assert "weight" in params assert params["weight"].shape == (out_features, big_in_features) assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha) diff --git a/tests/backend/patches/layers/test_lora_layer.py b/tests/backend/patches/layers/test_lora_layer.py index 34f62c3bcf..c0971fb9a1 100644 --- a/tests/backend/patches/layers/test_lora_layer.py +++ b/tests/backend/patches/layers/test_lora_layer.py @@ -107,7 +107,7 @@ def test_lora_layer_get_parameters(): # Create mock original module orig_module = torch.nn.Linear(in_features, out_features) - params = layer.get_parameters(orig_module, weight=1.0) + params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0) assert "weight" in params assert params["weight"].shape == orig_module.weight.shape assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha) diff --git a/tests/backend/patches/layers/test_set_parameter_layer.py b/tests/backend/patches/layers/test_set_parameter_layer.py index 0bca0293f5..bdf8e33749 100644 --- a/tests/backend/patches/layers/test_set_parameter_layer.py +++ b/tests/backend/patches/layers/test_set_parameter_layer.py @@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters(): target_weight = torch.randn(8, 4) layer = SetParameterLayer(param_name="weight", weight=target_weight) - params = layer.get_parameters(orig_module, weight=1.0) + params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0) assert len(params) == 1 new_weight = orig_module.weight + params["weight"] assert torch.allclose(new_weight, target_weight) diff --git a/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py deleted file mode 100644 index ee0dce554f..0000000000 --- a/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch - -from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer -from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper - - -def test_flux_rms_norm_sidecar_wrapper(): - # Create a RMSNorm layer. - dim = 10 - rms_norm = torch.nn.RMSNorm(dim) - - # Create a SetParameterLayer. - new_scale = torch.randn(dim) - set_parameter_layer = SetParameterLayer("scale", new_scale) - - # Create a FluxRMSNormSidecarWrapper. - rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)]) - - # Run the FluxRMSNormSidecarWrapper. 
- input = torch.randn(1, dim) - expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6) - output_wrapped = rms_norm_wrapped(input) - assert torch.allclose(output_wrapped, expected_output, atol=1e-6) diff --git a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py deleted file mode 100644 index 607f364dcd..0000000000 --- a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py +++ /dev/null @@ -1,182 +0,0 @@ -import copy - -import torch - -from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer -from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer -from invokeai.backend.patches.layers.full_layer import FullLayer -from invokeai.backend.patches.layers.lora_layer import LoRALayer -from invokeai.backend.patches.pad_with_zeros import pad_with_zeros -from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper - - -@torch.no_grad() -def test_linear_sidecar_wrapper_lora(): - # Create a linear layer. - in_features = 10 - out_features = 20 - linear = torch.nn.Linear(in_features, out_features) - - # Create a LoRA layer. - rank = 4 - down = torch.randn(rank, in_features) - up = torch.randn(out_features, rank) - bias = torch.randn(out_features) - lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias) - - # Patch the LoRA layer into the linear layer. - linear_patched = copy.deepcopy(linear) - linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale() - linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale() - - # Create a LinearSidecarWrapper. - lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)]) - - # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, in_features) - output_patched = linear_patched(input) - output_wrapped = lora_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) - - -@torch.no_grad() -def test_linear_sidecar_wrapper_multiple_loras(): - # Create a linear layer. - in_features = 10 - out_features = 20 - linear = torch.nn.Linear(in_features, out_features) - - # Create two LoRA layers. - rank = 4 - lora_layer = LoRALayer( - up=torch.randn(out_features, rank), - mid=None, - down=torch.randn(rank, in_features), - alpha=1.0, - bias=torch.randn(out_features), - ) - lora_layer_2 = LoRALayer( - up=torch.randn(out_features, rank), - mid=None, - down=torch.randn(rank, in_features), - alpha=1.0, - bias=torch.randn(out_features), - ) - # We use different weights for the two LoRA layers to ensure this is working. - lora_weight = 1.0 - lora_weight_2 = 0.5 - - # Patch the LoRA layers into the linear layer. - linear_patched = copy.deepcopy(linear) - linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * (lora_layer.scale() * lora_weight) - linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * (lora_layer.scale() * lora_weight) - linear_patched.weight.data += lora_layer_2.get_weight(linear_patched.weight) * ( - lora_layer_2.scale() * lora_weight_2 - ) - linear_patched.bias.data += lora_layer_2.get_bias(linear_patched.bias) * (lora_layer_2.scale() * lora_weight_2) - - # Create a LinearSidecarWrapper. 
- lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, lora_weight), (lora_layer_2, lora_weight_2)]) - - # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, in_features) - output_patched = linear_patched(input) - output_wrapped = lora_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) - - -@torch.no_grad() -def test_linear_sidecar_wrapper_concatenated_lora(): - # Create a linear layer. - in_features = 5 - sub_layer_out_features = [5, 10, 15] - linear = torch.nn.Linear(in_features, sum(sub_layer_out_features)) - - # Create a ConcatenatedLoRA layer. - rank = 4 - sub_layers: list[LoRALayer] = [] - for out_features in sub_layer_out_features: - down = torch.randn(rank, in_features) - up = torch.randn(out_features, rank) - bias = torch.randn(out_features) - sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)) - concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0) - - # Patch the ConcatenatedLoRA layer into the linear layer. - linear_patched = copy.deepcopy(linear) - linear_patched.weight.data += ( - concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale() - ) - linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale() - - # Create a LinearSidecarWrapper. - lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)]) - - # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, in_features) - output_patched = linear_patched(input) - output_wrapped = lora_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) - - -def test_linear_sidecar_wrapper_full_layer(): - # Create a linear layer. - in_features = 10 - out_features = 20 - linear = torch.nn.Linear(in_features, out_features) - - # Create a FullLayer. - full_layer = FullLayer(weight=torch.randn(out_features, in_features), bias=torch.randn(out_features)) - - # Patch the FullLayer into the linear layer. - linear_patched = copy.deepcopy(linear) - linear_patched.weight.data += full_layer.get_weight(linear_patched.weight) - linear_patched.bias.data += full_layer.get_bias(linear_patched.bias) - - # Create a LinearSidecarWrapper. - full_wrapped = LinearSidecarWrapper(linear, [(full_layer, 1.0)]) - - # Run the FullLayer-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, in_features) - output_patched = linear_patched(input) - output_wrapped = full_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) - - -def test_linear_sidecar_wrapper_flux_control_lora_layer(): - # Create a linear layer. - orig_in_features = 10 - out_features = 40 - linear = torch.nn.Linear(orig_in_features, out_features) - - # Create a FluxControlLoRALayer. - patched_in_features = 20 - rank = 4 - lora_layer = FluxControlLoRALayer( - up=torch.randn(out_features, rank), - mid=None, - down=torch.randn(rank, patched_in_features), - alpha=1.0, - bias=torch.randn(out_features), - ) - - # Patch the FluxControlLoRALayer into the linear layer. - linear_patched = copy.deepcopy(linear) - # Expand the existing weight. - expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features])) - linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad) - # Expand the existing bias. 
-    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
-    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
-    # Add the residuals.
-    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
-    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
-
-    # Create a LinearSidecarWrapper.
-    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])
-
-    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
-    input = torch.randn(1, patched_in_features)
-    output_patched = linear_patched(input)
-    output_wrapped = lora_wrapped(input)
-    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
diff --git a/tests/backend/patches/test_layer_patcher.py b/tests/backend/patches/test_layer_patcher.py
new file mode 100644
index 0000000000..84741fc60a
--- /dev/null
+++ b/tests/backend/patches/test_layer_patcher.py
@@ -0,0 +1,313 @@
+import pytest
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
+    CachedModelWithPartialLoad,
+)
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
+    apply_custom_layers_to_model,
+)
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+
+
+class DummyModuleWithOneLayer(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
+        super().__init__()
+        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear_layer_1(x)
+
+
+class DummyModuleWithTwoLayers(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
+        super().__init__()
+        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
+        self.linear_layer_2 = torch.nn.Linear(out_features, out_features, device=device, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear_layer_2(self.linear_layer_1(x))
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
+    ],
+)
+@pytest.mark.parametrize("num_loras", [1, 2])
+@pytest.mark.parametrize(
+    ["force_sidecar_patching", "force_direct_patching"], [(True, False), (False, True), (False, False)]
+)
+@torch.no_grad()
+def test_apply_smart_model_patches(
+    device: str, num_loras: int, force_sidecar_patching: bool, force_direct_patching: bool
+):
+    """Test the basic behavior of LayerPatcher.apply_smart_model_patches(...). Check that patching and unpatching work correctly."""
+    dtype = torch.float16
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_rank = 2
+    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
+    apply_custom_layers_to_model(model)
+
+    # Initialize num_loras LoRA models with weights of 0.5.
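+    # Each LoRA uses all-ones up/down matrices, so up @ down equals lora_rank in every element. Assuming the layer's
+    # default scale of 1.0 (no alpha provided), each LoRA therefore shifts every weight by lora_rank * lora_weight,
+    # which is where expected_patched_linear_weight below comes from.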
+    lora_weight = 0.5
+    lora_models: list[tuple[ModelPatchRaw, float]] = []
+    for _ in range(num_loras):
+        lora_layers = {
+            "linear_layer_1": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            )
+        }
+        lora = ModelPatchRaw(lora_layers)
+        lora_models.append((lora, lora_weight))
+
+    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
+    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)
+
+    # Run inference before patching the model.
+    input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
+    output_before_patch = model(input)
+
+    expect_sidecar_wrappers = device == "cpu"
+    if force_sidecar_patching:
+        expect_sidecar_wrappers = True
+    elif force_direct_patching:
+        expect_sidecar_wrappers = False
+
+    # Patch the model and run inference during the patch.
+    with LayerPatcher.apply_smart_model_patches(
+        model=model,
+        patches=lora_models,
+        prefix="",
+        dtype=dtype,
+        force_direct_patching=force_direct_patching,
+        force_sidecar_patching=force_sidecar_patching,
+    ):
+        if expect_sidecar_wrappers:
+            # There should be sidecar patches in the model.
+            assert model.linear_layer_1.get_num_patches() == num_loras
+        else:
+            # There should be no sidecar patches in the model.
+            assert model.linear_layer_1.get_num_patches() == 0
+            torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)
+
+        # After patching, the patched model should still be on its original device.
+        assert model.linear_layer_1.weight.data.device.type == device
+
+        # After patching, all LoRA layer weights should have been moved back to the cpu.
+        for lora, _ in lora_models:
+            assert lora.layers["linear_layer_1"].up.device.type == "cpu"
+            assert lora.layers["linear_layer_1"].down.device.type == "cpu"
+
+        output_during_patch = model(input)
+
+    # Run inference after unpatching.
+    output_after_patch = model(input)
+
+    # Check that the output before patching is different from the output during patching.
+    assert not torch.allclose(output_before_patch, output_during_patch)
+
+    # Check that the output before patching is the same as the output after patching.
+    assert torch.allclose(output_before_patch, output_after_patch)
+
+
+@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
+@torch.no_grad()
+def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int):
+    """Test the behavior of LayerPatcher.apply_smart_model_patches(...) when it is applied to a
+    CachedModelWithPartialLoad that is partially loaded into VRAM.
+    """
+
+    if not torch.cuda.is_available():
+        pytest.skip("requires CUDA device")
+
+    # Initialize the model on the CPU.
+    dtype = torch.float16
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_rank = 2
+    model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
+    apply_custom_layers_to_model(model)
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
+    model_total_bytes = cached_model.total_bytes()
+    assert cached_model.cur_vram_bytes() == 0
+
+    # Partially load the model into VRAM.
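+    # Loading ~60% of the model is expected to leave linear_layer_1 on the GPU and linear_layer_2 on the CPU (asserted
+    # below), so smart patching should patch the first layer directly and fall back to a sidecar patch for the second.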
+    target_vram_bytes = int(model_total_bytes * 0.6)
+    _ = cached_model.partial_load_to_vram(target_vram_bytes)
+    assert cached_model.model.linear_layer_1.weight.device.type == "cuda"
+    assert cached_model.model.linear_layer_2.weight.device.type == "cpu"
+
+    # Initialize num_loras LoRA models with weights of 0.5.
+    lora_weight = 0.5
+    lora_models: list[tuple[ModelPatchRaw, float]] = []
+    for _ in range(num_loras):
+        lora_layers = {
+            "linear_layer_1": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            ),
+            "linear_layer_2": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            ),
+        }
+        lora = ModelPatchRaw(lora_layers)
+        lora_models.append((lora, lora_weight))
+
+    # Run inference before patching the model.
+    input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype)
+    output_before_patch = cached_model.model(input)
+
+    # Patch the model and run inference during the patch.
+    with LayerPatcher.apply_smart_model_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype):
+        # Check that the second layer has sidecar patches, but the first layer does not.
+        assert cached_model.model.linear_layer_1.get_num_patches() == 0
+        assert cached_model.model.linear_layer_2.get_num_patches() == num_loras
+
+        output_during_patch = cached_model.model(input)
+
+    # Run inference after unpatching.
+    output_after_patch = cached_model.model(input)
+
+    # Check that the output before patching is different from the output during patching.
+    assert not torch.allclose(output_before_patch, output_during_patch)
+
+    # Check that the output before patching is the same as the output after patching.
+    assert torch.allclose(output_before_patch, output_after_patch)
+
+
+@torch.no_grad()
+@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
+def test_all_patching_methods_produce_same_output(num_loras: int):
+    """Test that forced direct patching, forced sidecar patching, and smart patching all produce the same model outputs."""
+    dtype = torch.float32
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_rank = 2
+    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
+    apply_custom_layers_to_model(model)
+
+    # Initialize num_loras LoRA models with weights of 0.5.
+    lora_weight = 0.5
+    lora_models: list[tuple[ModelPatchRaw, float]] = []
+    for _ in range(num_loras):
+        lora_layers = {
+            "linear_layer_1": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            )
+        }
+        lora = ModelPatchRaw(lora_layers)
+        lora_models.append((lora, lora_weight))
+
+    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)
+
+    with LayerPatcher.apply_smart_model_patches(
+        model=model, patches=lora_models, prefix="", dtype=dtype, force_direct_patching=True
+    ):
+        output_force_direct = model(input)
+
+    with LayerPatcher.apply_smart_model_patches(
+        model=model, patches=lora_models, prefix="", dtype=dtype, force_sidecar_patching=True
+    ):
+        output_force_sidecar = model(input)
+
+    with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+        output_smart = model(input)
+
+    # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
+    # differences are tolerable and expected due to the differences between sidecar patching and direct patching.
+    assert torch.allclose(output_force_direct, output_force_sidecar, atol=1e-5)
+    assert torch.allclose(output_force_direct, output_smart, atol=1e-5)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
+@torch.no_grad()
+def test_apply_smart_model_patches_change_device():
+    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
+    still behaves correctly.
+    """
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_dim = 2
+    # Initialize the model on the CPU.
+    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
+    apply_custom_layers_to_model(model)
+
+    lora_layers = {
+        "linear_layer_1": LoRALayer.from_state_dict_values(
+            values={
+                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
+                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
+            },
+        )
+    }
+    lora = ModelPatchRaw(lora_layers)
+
+    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
+
+    with LayerPatcher.apply_smart_model_patches(
+        model=model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16, force_direct_patching=True
+    ):
+        # After patching, all LoRA layer weights should have been moved back to the cpu.
+        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
+        assert lora_layers["linear_layer_1"].down.device.type == "cpu"
+
+        # After patching, the patched model should still be on the CPU.
+        assert model.linear_layer_1.weight.data.device.type == "cpu"
+
+        # There should be no sidecar patches in the model.
+        assert model.linear_layer_1.get_num_patches() == 0
+
+        # Move the model to the GPU.
+        model.to("cuda")
+
+    # After unpatching, the original model weights should have been restored on the GPU.
+    assert model.linear_layer_1.weight.data.device.type == "cuda"
+    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)
+
+
+def test_apply_smart_model_patches_force_sidecar_and_direct_patching():
+    """Test that LayerPatcher.apply_smart_model_patches(..., force_direct_patching=True, force_sidecar_patching=True)
+    raises an error.
+ """ + linear_in_features = 4 + linear_out_features = 8 + lora_rank = 2 + model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) + apply_custom_layers_to_model(model) + + lora_layers = { + "linear_layer_1": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), + "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), + }, + ) + } + lora = ModelPatchRaw(lora_layers) + with pytest.raises(ValueError, match="Cannot force both direct and sidecar patching."): + with LayerPatcher.apply_smart_model_patches( + model=model, + patches=[(lora, 0.5)], + prefix="", + dtype=torch.float16, + force_direct_patching=True, + force_sidecar_patching=True, + ): + pass diff --git a/tests/backend/patches/test_lora_patcher.py b/tests/backend/patches/test_lora_patcher.py deleted file mode 100644 index 057504bb97..0000000000 --- a/tests/backend/patches/test_lora_patcher.py +++ /dev/null @@ -1,197 +0,0 @@ -import pytest -import torch - -from invokeai.backend.patches.layers.lora_layer import LoRALayer -from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher - - -class DummyModule(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype): - super().__init__() - self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear_layer_1(x) - - -@pytest.mark.parametrize( - ["device", "num_layers"], - [ - ("cpu", 1), - pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ("cpu", 2), - pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ], -) -@torch.no_grad() -def test_apply_lora_patches(device: str, num_layers: int): - """Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the - correct result, and that model/LoRA tensors are moved between devices as expected. - """ - - linear_in_features = 4 - linear_out_features = 8 - lora_rank = 2 - model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16) - - # Initialize num_layers LoRA models with weights of 0.5. - lora_weight = 0.5 - lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_layers): - lora_layers = { - "linear_layer_1": LoRALayer.from_state_dict_values( - values={ - "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), - "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), - }, - ) - } - lora = ModelPatchRaw(lora_layers) - lora_models.append((lora, lora_weight)) - - orig_linear_weight = model.linear_layer_1.weight.data.detach().clone() - expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers) - - with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""): - # After patching, all LoRA layer weights should have been moved back to the cpu. - for lora, _ in lora_models: - assert lora.layers["linear_layer_1"].up.device.type == "cpu" - assert lora.layers["linear_layer_1"].down.device.type == "cpu" - - # After patching, the patched model should still be on its original device. 
- assert model.linear_layer_1.weight.data.device.type == device - - torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight) - - # After unpatching, the original model weights should have been restored on the original device. - assert model.linear_layer_1.weight.data.device.type == device - torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device") -@torch.no_grad() -def test_apply_lora_patches_change_device(): - """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching - still behaves correctly. - """ - linear_in_features = 4 - linear_out_features = 8 - lora_dim = 2 - # Initialize the model on the CPU. - model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) - - lora_layers = { - "linear_layer_1": LoRALayer.from_state_dict_values( - values={ - "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16), - "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16), - }, - ) - } - lora = ModelPatchRaw(lora_layers) - - orig_linear_weight = model.linear_layer_1.weight.data.detach().clone() - - with LayerPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""): - # After patching, all LoRA layer weights should have been moved back to the cpu. - assert lora_layers["linear_layer_1"].up.device.type == "cpu" - assert lora_layers["linear_layer_1"].down.device.type == "cpu" - - # After patching, the patched model should still be on the CPU. - assert model.linear_layer_1.weight.data.device.type == "cpu" - - # Move the model to the GPU. - assert model.to("cuda") - - # After unpatching, the original model weights should have been restored on the GPU. - assert model.linear_layer_1.weight.data.device.type == "cuda" - torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False) - - -@pytest.mark.parametrize( - ["device", "num_layers"], - [ - ("cpu", 1), - pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ("cpu", 2), - pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ], -) -def test_apply_lora_sidecar_patches(device: str, num_layers: int): - """Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly.""" - dtype = torch.float16 - linear_in_features = 4 - linear_out_features = 8 - lora_rank = 2 - model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype) - - # Initialize num_layers LoRA models with weights of 0.5. - lora_weight = 0.5 - lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_layers): - lora_layers = { - "linear_layer_1": LoRALayer.from_state_dict_values( - values={ - "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), - "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), - }, - ) - } - lora = ModelPatchRaw(lora_layers) - lora_models.append((lora, lora_weight)) - - # Run inference before patching the model. - input = torch.randn(1, linear_in_features, device=device, dtype=dtype) - output_before_patch = model(input) - - # Patch the model and run inference during the patch. 
- with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype): - output_during_patch = model(input) - - # Run inference after unpatching. - output_after_patch = model(input) - - # Check that the output before patching is different from the output during patching. - assert not torch.allclose(output_before_patch, output_during_patch) - - # Check that the output before patching is the same as the output after patching. - assert torch.allclose(output_before_patch, output_after_patch) - - -@torch.no_grad() -@pytest.mark.parametrize(["num_layers"], [(1,), (2,)]) -def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int): - """Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...).""" - dtype = torch.float32 - linear_in_features = 4 - linear_out_features = 8 - lora_rank = 2 - model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype) - - # Initialize num_layers LoRA models with weights of 0.5. - lora_weight = 0.5 - lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_layers): - lora_layers = { - "linear_layer_1": LoRALayer.from_state_dict_values( - values={ - "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), - "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), - }, - ) - } - lora = ModelPatchRaw(lora_layers) - lora_models.append((lora, lora_weight)) - - input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype) - - with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""): - output_lora_patches = model(input) - - with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype): - output_lora_sidecar_patches = model(input) - - # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical - # differences are tolerable and expected due to the difference between sidecar vs. patching. - assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
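
For readers following the API change exercised by these tests, the sketch below shows how a caller might move from the removed apply_model_patches(...) / apply_model_sidecar_patches(...) entry points to the unified LayerPatcher.apply_smart_model_patches(...). It is a minimal illustration based only on the signatures used in this diff (model, patches, prefix, dtype, and the force_direct_patching / force_sidecar_patching flags); the helper name run_with_patches and the quantized flag are hypothetical, not part of the library.

    # Minimal sketch (assumed usage, mirroring the calls in the tests above).
    import torch

    from invokeai.backend.patches.layer_patcher import LayerPatcher
    from invokeai.backend.patches.model_patch_raw import ModelPatchRaw


    def run_with_patches(
        model: torch.nn.Module,
        patches: list[tuple[ModelPatchRaw, float]],
        x: torch.Tensor,
        quantized: bool = False,
    ) -> torch.Tensor:
        dtype = next(model.parameters()).dtype
        # One call replaces the old direct/sidecar pair: by default the patcher chooses direct or
        # sidecar patching per layer, and force_sidecar_patching=True reproduces the sidecar-only
        # path (e.g. for quantized models, as in flux_denoise.py).
        with LayerPatcher.apply_smart_model_patches(
            model=model,
            patches=patches,
            prefix="",  # e.g. "lora_unet_" or "lora_te_" in the invocation code
            dtype=dtype,
            force_sidecar_patching=quantized,
        ):
            return model(x)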