Partial Loading PR3: Integrate 1) partial loading, 2) quantized models, 3) model patching (#7500)

## Summary

This PR is the third in a sequence of PRs working towards support for
partial loading of models onto the compute device (for low-VRAM
operation). This PR updates the LoRA patching code so that the following
features can cooperate fully:
- Partial loading of weights onto the GPU
- Quantized layers / weights
- Model patches (e.g. LoRA)

Note that this PR does not yet enable partial loading. It adds support
in the model patching code so that partial loading can be enabled in a
future PR.
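
As can be seen in the diff below, all call sites now go through a single context manager, `LayerPatcher.apply_smart_model_patches(...)`, which decides per module whether to patch the weights directly or to attach a sidecar patch. A minimal usage sketch (the variable names are placeholders; the keyword arguments mirror those used in the invocations in this diff):

```python
from invokeai.backend.patches.layer_patcher import LayerPatcher

# Placeholder inputs: `text_encoder` is the model being patched and `lora_patches` is an
# iterable of (ModelPatchRaw, float) tuples, as in the invocations touched by this PR.
with LayerPatcher.apply_smart_model_patches(
    model=text_encoder,
    patches=lora_patches,
    prefix="lora_te_",
    dtype=text_encoder.dtype,
    cached_weights=cached_weights,  # optional read-only CPU copy of the state dict, used for unpatching
):
    ...  # run inference; direct patches are reverted and sidecar patches are cleared on exit
```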

## Technical Design Decisions

The layer patching logic has been integrated into the custom layers (via
`CustomModuleMixin`) rather than being kept in a separate set of wrapper
layers, as it was previously. This has the following advantages:
- It makes it easier to calculate the modified weights on the fly and
then reuse the normal forward() logic.
- In the future, it makes it possible to pass original parameters that have
already been cast to the device down to the LoRA calculation without having
to re-cast them (the current implementation doesn't fully take advantage of
this yet). See the sketch after this list.
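
To make the custom-layer pattern concrete, here is a simplified sketch modeled on the `CustomConv*`/`CustomLinear` classes in this diff. `cast_to_device`, `CustomModuleMixin`, `_aggregate_patch_parameters`, and `add_nullable_tensors` are the real helpers from this PR; the class itself is illustrative (the actual `CustomLinear` additionally has an optimized path for LoRA-style patches):

```python
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
    add_nullable_tensors,
)


class SketchCustomLinear(torch.nn.Linear, CustomModuleMixin):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            # Sidecar path: compute the patch residuals on the fly, add them to the
            # (possibly streamed) original parameters, and reuse the normal functional op.
            weight = cast_to_device(self.weight, input.device)
            bias = cast_to_device(self.bias, input.device)
            orig_params = {k: v for k, v in {"weight": weight, "bias": bias}.items() if v is not None}
            residuals = self._aggregate_patch_parameters(
                patches_and_weights=self._patches_and_weights, orig_params=orig_params, device=input.device
            )
            weight = add_nullable_tensors(weight, residuals.get("weight", None))
            bias = add_nullable_tensors(bias, residuals.get("bias", None))
            return torch.nn.functional.linear(input, weight, bias)
        elif self._device_autocasting_enabled:
            # Partial-loading path: parameters may still live on the CPU, so cast them to
            # the input's device before the standard functional call.
            weight = cast_to_device(self.weight, input.device)
            bias = cast_to_device(self.bias, input.device)
            return torch.nn.functional.linear(input, weight, bias)
        else:
            # Fast path: fully loaded and unpatched, so plain nn.Linear behavior.
            return super().forward(input)
```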

## Known Limitations

1. I haven't fully solved device management for patch types that require
the original layer value to calculate the patch. These aren't very common
and are not compatible with some quantized layers, so I'm leaving this for
the future if there's demand.
2. There is a small speed regression for models that have CPU
bottlenecks. This seems to be caused by slightly slower method
resolution on the custom layer sub-classes. The regression does not
show up on larger models, like FLUX, that are almost entirely
GPU-limited. I think this small regression is tolerable, but if we
decide that it's not, then the slowdown can easily be reclaimed by
optimizing other CPU operations (e.g. if we only sent every 2nd progress
image, we'd see a much more significant speedup).

## Related Issues / Discussions

- https://github.com/invoke-ai/InvokeAI/pull/7492
- https://github.com/invoke-ai/InvokeAI/pull/7494

## QA Instructions

Speed tests:
- Vanilla SD1 speed regression
    - Before: 3.156s (8.78 it/s)
    - After: 3.54s (8.35 it/s)
- Vanilla SDXL speed regression
    - Before: 6.23s (4.46 it/s)
    - After: 6.45s (4.31 it/s)
- Vanilla FLUX speed regression
    - Before: 12.02s (2.27 it/s)
    - After: 11.91s (2.29 it/s)

LoRA tests with default configuration:
- [x] SD1: A handful of LoRA variants
- [x] SDXL: A handful of LoRA variants
- [x] FLUX non-quantized: multiple LoRA variants
- [x] FLUX bnb-quantized: multiple LoRA variants
- [x] FLUX ggml-quantized: multiple LoRA variants
- [x] FLUX non-quantized: FLUX control LoRA
- [x] FLUX bnb-quantized: FLUX control LoRA
- [x] FLUX ggml-quantized: FLUX control LoRA

LoRA tests with sidecar patching forced:
- [x] SD1: A handful of LoRA variants
- [x] SDXL: A handful of LoRA variants
- [x] FLUX non-quantized: multiple LoRA variants
- [x] FLUX bnb-quantized: multiple LoRA variants
- [x] FLUX ggml-quantized: multiple LoRA variants
- [x] FLUX non-quantized: FLUX control LoRA
- [x] FLUX bnb-quantized: FLUX control LoRA
- [x] FLUX ggml-quantized: FLUX control LoRA

Other:
- [x] Smoke testing of IP-Adapter, ControlNet

All tests repeated on:
- [x] cuda
- [x] cpu (only tested SD1, because larger models are prohibitively slow)
- [x] mps (skipped FLUX tests, because my Mac doesn't have enough memory
to run them in a reasonable amount of time)

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
Commit b46d7abfb0 by Ryan Dick, 2024-12-31 13:58:13 -05:00 (committed by GitHub)
50 changed files with 1734 additions and 1035 deletions

View File

@ -20,8 +20,8 @@ from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LayerPatcher.apply_model_patches(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
dtype=text_encoder.dtype,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LayerPatcher.apply_model_patches(
text_encoder,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
prefix=lora_prefix,
dtype=text_encoder.dtype,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.

View File

@ -39,8 +39,8 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
LayerPatcher.apply_model_patches(
LayerPatcher.apply_smart_model_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
dtype=unet.dtype,
cached_weights=cached_weights,
),
):

View File

@ -48,9 +48,9 @@ from invokeai.backend.flux.sampling_utils import (
)
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@ -304,36 +304,33 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
config = transformer_info.config
assert config is not None
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
cached_weights=cached_weights,
)
)
model_is_quantized = False
elif config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LayerPatcher.apply_model_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
)
)
model_is_quantized = True
else:
raise ValueError(f"Unsupported model format: {config.format}")
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)
# Prepare IP-Adapter extensions.
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,

View File

@ -18,9 +18,9 @@ from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@ -111,10 +111,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
LayerPatcher.apply_smart_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=clip_text_encoder.dtype,
cached_weights=cached_weights,
)
)

View File

@ -17,9 +17,9 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
@ -150,10 +150,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
LayerPatcher.apply_smart_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=clip_text_encoder.dtype,
cached_weights=cached_weights,
)
)

View File

@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
with (
ExitStack() as exit_stack,
unet_info as unet,
LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
LayerPatcher.apply_smart_model_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)

View File

@ -1,9 +1,7 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
AUTOCAST_MODULE_TYPE_MAPPING,
apply_custom_layers_to_model,
remove_custom_layers_from_model,
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger
@ -45,10 +43,10 @@ class CachedModelWithPartialLoad:
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
"""Find all modules that support autocasting."""
return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}
return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
keys_in_modules_that_do_not_support_autocast = set()
keys_in_modules_that_do_not_support_autocast: set[str] = set()
for key in self._cpu_state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
@ -70,6 +68,11 @@ class CachedModelWithPartialLoad:
if name in module._non_persistent_buffers_set:
module._buffers[name] = buffer.to(device, copy=True)
def _set_autocast_enabled_in_all_modules(self, enabled: bool):
"""Set autocast_enabled flag in all modules that support device autocasting."""
for module in self._modules_that_support_autocast.values():
module.set_device_autocasting_enabled(enabled)
@property
def model(self) -> torch.nn.Module:
return self._model
@ -114,7 +117,7 @@ class CachedModelWithPartialLoad:
cur_state_dict = self._model.state_dict()
# First, process the keys *must* be loaded into VRAM.
# First, process the keys that *must* be loaded into VRAM.
for key in self._keys_in_modules_that_do_not_support_autocast:
param = cur_state_dict[key]
if param.device.type == self._compute_device.type:
@ -157,10 +160,10 @@ class CachedModelWithPartialLoad:
self._cur_vram_bytes += vram_bytes_loaded
if fully_loaded:
remove_custom_layers_from_model(self._model)
self._set_autocast_enabled_in_all_modules(False)
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
else:
apply_custom_layers_to_model(self._model)
self._set_autocast_enabled_in_all_modules(True)
# Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
# the vram_bytes_loaded tracking.
@ -197,5 +200,5 @@ class CachedModelWithPartialLoad:
# We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
# layers.
apply_custom_layers_to_model(self._model)
self._set_autocast_enabled_in_all_modules(True)
return vram_bytes_freed

View File

@ -13,6 +13,9 @@ from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
apply_custom_layers_to_model,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@ -143,6 +146,10 @@ class ModelCache:
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)
# Inject custom modules into the model.
if isinstance(model, torch.nn.Module):
apply_custom_layers_to_model(model)
running_on_cpu = self._execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)

View File

@ -1,50 +0,0 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that is is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OrginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.
class CustomLinear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)
class CustomConv1d(torch.nn.Conv1d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)
class CustomConv2d(torch.nn.Conv2d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)
class CustomGroupNorm(torch.nn.GroupNorm):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
class CustomEmbedding(torch.nn.Embedding):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
return torch.nn.functional.embedding(
input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)

View File

@ -0,0 +1,8 @@
This directory contains custom implementations of common torch.nn.Module classes that add support for:
- Streaming weights to the execution device
- Applying sidecar patches at execution time (e.g. sidecar LoRA layers)
Each custom class sub-classes the original module type that it is replacing, so the following properties are preserved:
- `isinstance(m, torch.nn.OriginalModule)` should still work.
- Patching the weights directly (e.g. for LoRA) should still work. (Of course, this is not possible for quantized layers, hence the sidecar support.)
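
A brief usage sketch of how these custom modules are installed and removed (illustrative; it uses the `apply_custom_layers_to_model` / `remove_custom_layers_from_model` helpers updated elsewhere in this diff):

```python
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
    remove_custom_layers_from_model,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GroupNorm(2, 8))

# Swap in the custom module subclasses (optionally with device autocasting enabled).
apply_custom_layers_to_model(model, device_autocasting_enabled=True)

# The original class hierarchy is preserved, so existing isinstance() checks keep working,
# and the weights of non-quantized layers can still be patched directly.
assert isinstance(model[0], torch.nn.Linear)
assert isinstance(model[1], torch.nn.GroupNorm)

# Restore the stock torch.nn classes when the custom behavior is no longer needed.
remove_custom_layers_from_model(model)
```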

View File

@ -0,0 +1,43 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)
class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": weight, "bias": bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}
aggregated_param_residuals = self._aggregate_patch_parameters(
patches_and_weights=self._patches_and_weights,
orig_params=orig_params,
device=input.device,
)
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(input)
elif self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)

View File

@ -0,0 +1,43 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)
class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": weight, "bias": bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}
aggregated_param_residuals = self._aggregate_patch_parameters(
patches_and_weights=self._patches_and_weights,
orig_params=orig_params,
device=input.device,
)
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(input)
elif self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)

View File

@ -0,0 +1,29 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin):
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
return torch.nn.functional.embedding(
input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
raise RuntimeError("Embedding layers do not support patches")
if self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)

View File

@ -0,0 +1,36 @@
import torch
from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
# Currently, CustomFluxRMSNorm layers only support patching with a single SetParameterLayer.
assert len(self._patches_and_weights) == 1
patch, _patch_weight = self._patches_and_weights[0]
assert isinstance(patch, SetParameterLayer)
assert patch.param_name == "scale"
scale = cast_to_device(patch.weight, x.device)
# Apply the patch.
# NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
# be handled.
return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
scale = cast_to_device(self.scale, x.device)
return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)

View File

@ -0,0 +1,22 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin):
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
raise RuntimeError("GroupNorm layers do not support patches")
if self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)

View File

@ -2,11 +2,20 @@ import bitsandbytes as bnb
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
autocast_linear_forward_sidecar_patches,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
def forward(self, x: torch.Tensor) -> torch.Tensor:
class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
matmul_state = bnb.MatmulLtState()
matmul_state.threshold = self.state.threshold
matmul_state.has_fp16_weights = self.state.has_fp16_weights
@ -25,3 +34,11 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
# it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
# on the wrong device.
return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)

View File

@ -4,11 +4,20 @@ import bitsandbytes as bnb
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
autocast_linear_forward_sidecar_patches,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
class CustomInvokeLinearNF4(InvokeLinearNF4):
def forward(self, x: torch.Tensor) -> torch.Tensor:
class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)
# weights are cast automatically as Int8Params, but the bias has to be cast manually
@ -43,3 +52,11 @@ class CustomInvokeLinearNF4(InvokeLinearNF4):
bias = cast_to_device(self.bias, x.device)
return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)

View File

@ -0,0 +1,106 @@
import copy
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
def concatenated_lora_forward(
input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def autocast_linear_forward_sidecar_patches(
orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> torch.Tensor:
"""A function that runs a linear layer (quantized or non-quantized) with sidecar patches for a linear layer.
Compatible with both quantized and non-quantized Linear layers.
"""
# First, apply the original linear layer.
# NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
# change the linear layer's in_features.
orig_input = input
input = orig_input[..., : orig_module.in_features]
output = orig_module._autocast_forward(input)
# Then, apply layers for which we have optimized implementations.
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
for patch, patch_weight in patches_and_weights:
# Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
patch = copy.copy(patch)
patch.to(input.device)
if isinstance(patch, FluxControlLoRALayer):
# Note that we use the original input here, not the sliced input.
output += linear_lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += linear_lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += concatenated_lora_forward(input, patch, patch_weight)
else:
unprocessed_patches_and_weights.append((patch, patch_weight))
# Finally, apply any remaining patches.
if len(unprocessed_patches_and_weights) > 0:
# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}
aggregated_param_residuals = orig_module._aggregate_patch_parameters(
unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
)
output += torch.nn.functional.linear(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
return output
class CustomLinear(torch.nn.Linear, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(input)
elif self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)

View File

@ -0,0 +1,63 @@
import copy
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
class CustomModuleMixin:
"""A mixin class for custom modules that enables device autocasting of module parameters."""
def __init__(self):
self._device_autocasting_enabled = False
self._patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
def set_device_autocasting_enabled(self, enabled: bool):
"""Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to
disable autocasting, which results in slightly faster execution speed when we know that device autocasting is
not needed.
"""
self._device_autocasting_enabled = enabled
def is_device_autocasting_enabled(self) -> bool:
"""Check if device autocasting is enabled for the module."""
return self._device_autocasting_enabled
def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
"""Add a patch to the module."""
self._patches_and_weights.append((patch, patch_weight))
def clear_patches(self):
"""Clear all patches from the module."""
self._patches_and_weights = []
def get_num_patches(self) -> int:
"""Get the number of patches in the module."""
return len(self._patches_and_weights)
def _aggregate_patch_parameters(
self,
patches_and_weights: list[tuple[BaseLayerPatch, float]],
orig_params: dict[str, torch.Tensor],
device: torch.device | None = None,
):
"""Helper function that aggregates the parameters from all patches into a single dict."""
params: dict[str, torch.Tensor] = {}
for patch, patch_weight in patches_and_weights:
if device is not None:
# Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
patch = copy.copy(patch)
patch.to(device)
# TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original
# parameters, this might fail or return incorrect results.
layer_params = patch.get_parameters(orig_params, weight=patch_weight)
for param_name, param_weight in layer_params.items():
if param_name not in params:
params[param_name] = param_weight
else:
params[param_name] += param_weight
return params

View File

@ -0,0 +1,30 @@
from typing import overload
import torch
@overload
def add_nullable_tensors(a: None, b: None) -> None: ...
@overload
def add_nullable_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ...
@overload
def add_nullable_tensors(a: torch.Tensor, b: None) -> torch.Tensor: ...
@overload
def add_nullable_tensors(a: None, b: torch.Tensor) -> torch.Tensor: ...
def add_nullable_tensors(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor | None:
if a is None and b is None:
return None
elif a is None:
return b
elif b is None:
return a
else:
return a + b

View File

@ -1,12 +1,29 @@
from typing import TypeVar
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
CustomConv1d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import (
CustomConv2d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
CustomEmbedding,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
CustomFluxRMSNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
CustomGroupNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
CustomLinear,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
torch.nn.Linear: CustomLinear,
@ -14,14 +31,15 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
torch.nn.Conv2d: CustomConv2d,
torch.nn.GroupNorm: CustomGroupNorm,
torch.nn.Embedding: CustomEmbedding,
RMSNorm: CustomFluxRMSNorm,
}
try:
# These dependencies are not expected to be present on MacOS.
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import (
CustomInvokeLinear8bitLt,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import (
CustomInvokeLinearNF4,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
@ -33,24 +51,55 @@ except ImportError:
pass
def apply_custom_layers_to_model(model: torch.nn.Module):
def apply_custom_layers(module: torch.nn.Module):
override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
T = TypeVar("T", bound=torch.nn.Module)
def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T:
# HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
# existing layer instance without calling __init__() on the original layer class. We achieve this by copying
# the attributes from the original layer instance to the new instance.
custom_layer = custom_layer_type.__new__(custom_layer_type)
# Note that we share the __dict__.
# TODO(ryand): In the future, we may want to do a shallow copy of the __dict__.
custom_layer.__dict__ = module_to_wrap.__dict__
# Initialize the CustomModuleMixin fields.
CustomModuleMixin.__init__(custom_layer) # type: ignore
return custom_layer
def unwrap_custom_layer(custom_layer: torch.nn.Module, original_layer_type: type[torch.nn.Module]):
# HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
# existing layer instance without calling __init__() on the original layer class. We achieve this by copying
# the attributes from the original layer instance to the new instance.
original_layer = original_layer_type.__new__(original_layer_type)
# Note that we share the __dict__.
# TODO(ryand): In the future, we may want to do a shallow copy of the __dict__ and strip out the CustomModuleMixin
# fields.
original_layer.__dict__ = custom_layer.__dict__
return original_layer
def apply_custom_layers_to_model(module: torch.nn.Module, device_autocasting_enabled: bool = False):
for name, submodule in module.named_children():
override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(submodule), None)
if override_type is not None:
module.__class__ = override_type
# model.apply(...) calls apply_custom_layers(...) on each module in the model.
model.apply(apply_custom_layers)
custom_layer = wrap_custom_layer(submodule, override_type)
# TODO(ryand): In the future, we should manage this flag on a per-module basis.
custom_layer.set_device_autocasting_enabled(device_autocasting_enabled)
setattr(module, name, custom_layer)
else:
# Recursively apply to submodules
apply_custom_layers_to_model(submodule, device_autocasting_enabled)
def remove_custom_layers_from_model(model: torch.nn.Module):
# Invert AUTOCAST_MODULE_TYPE_MAPPING.
original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
def remove_custom_layers(module: torch.nn.Module):
override_type = original_module_type_mapping.get(type(module), None)
def remove_custom_layers_from_model(module: torch.nn.Module):
for name, submodule in module.named_children():
override_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE.get(type(submodule), None)
if override_type is not None:
module.__class__ = override_type
# model.apply(...) calls remove_custom_layers(...) on each module in the model.
model.apply(remove_custom_layers)
setattr(module, name, unwrap_custom_layer(submodule, override_type))
else:
remove_custom_layers_from_model(submodule)

View File

@ -7,8 +7,6 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
@ -17,58 +15,64 @@ class LayerPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_model_patches(
def apply_smart_model_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[ModelPatchRaw, float]],
prefix: str,
dtype: torch.dtype,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
force_direct_patching: bool = False,
force_sidecar_patching: bool = False,
):
"""Apply one or more LoRA patches to a model within a context manager.
Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
CPU RAM, for efficient unpatching purposes.
"""Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
module.
"""
# original_weights are stored for unpatching layers that are directly patched.
original_weights = OriginalWeightsStorage(cached_weights)
# original_modules are stored for unpatching layers that are wrapped.
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LayerPatcher.apply_model_patch(
LayerPatcher.apply_smart_model_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
original_modules=original_modules,
dtype=dtype,
force_direct_patching=force_direct_patching,
force_sidecar_patching=force_sidecar_patching,
)
del patch
yield
finally:
# Restore directly patched layers.
for param_key, weight in original_weights.get_changed_weights():
cur_param = model.get_parameter(param_key)
cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)
# Clear patches from all patched modules.
# Note: This logic assumes no nested modules in original_modules.
for orig_module in original_modules.values():
orig_module.clear_patches()
@staticmethod
@torch.no_grad()
def apply_model_patch(
def apply_smart_model_patch(
model: torch.nn.Module,
prefix: str,
patch: ModelPatchRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
force_direct_patching: bool,
force_sidecar_patching: bool,
):
"""Apply a single LoRA patch to a model.
Args:
model (torch.nn.Module): The model to patch.
prefix (str): A string prefix that precedes keys used in the LoRAs weight layers.
patch (LoRAModelRaw): The LoRA model to patch in.
patch_weight (float): The weight of the LoRA patch.
original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
patching or a sidecar wrapper for each module.
"""
if patch_weight == 0:
return
@ -89,13 +93,50 @@ class LayerPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
LayerPatcher._apply_model_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_weights=original_weights,
)
# Decide whether to use direct patching or a sidecar patch.
# Direct patching is preferred, because it results in better runtime speed.
# Reasons to use sidecar patching:
# - The module is quantized, so the caller passed force_sidecar_patching=True.
# - The module already has sidecar patches.
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
# CPU, since this would double the RAM usage)
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
# and that the caller will set force_sidecar_patching=True if the layer is quantized.
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
# forcing full patching even on the CPU?
use_sidecar_patching = False
if force_direct_patching and force_sidecar_patching:
raise ValueError("Cannot force both direct and sidecar patching.")
elif force_direct_patching:
use_sidecar_patching = False
elif force_sidecar_patching:
use_sidecar_patching = True
elif module.get_num_patches() > 0:
use_sidecar_patching = True
elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
use_sidecar_patching = True
if use_sidecar_patching:
LayerPatcher._apply_model_layer_wrapper_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
else:
LayerPatcher._apply_model_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_weights=original_weights,
)
@staticmethod
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
return any(p.device.type == "cpu" for p in layer.parameters())
@staticmethod
@torch.no_grad()
@ -120,7 +161,9 @@ class LayerPatcher:
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items():
for param_name, param_weight in patch.get_parameters(
dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
).items():
param_key = module_to_patch_key + "." + param_name
module_param = module_to_patch.get_parameter(param_name)
@ -143,93 +186,9 @@ class LayerPatcher:
patch.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
@torch.no_grad()
@contextmanager
def apply_model_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[ModelPatchRaw, float]],
prefix: str,
dtype: torch.dtype,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
quantization format.
Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LayerPatcher._apply_model_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
yield
finally:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LayerPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_model_sidecar_patch(
model: torch.nn.Module,
patch: ModelPatchRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA sidecar patch to a model."""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LayerPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
LayerPatcher._apply_model_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
@staticmethod
@torch.no_grad()
def _apply_model_layer_wrapper_patch(
model: torch.nn.Module,
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: BaseLayerPatch,
@ -237,25 +196,16 @@ class LayerPatcher:
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA wrapper patch to a model."""
# Replace the original module with a BaseSidecarWrapper if it has not already been done.
if not isinstance(module_to_patch, BaseSidecarWrapper):
wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch)
original_modules[module_to_patch_key] = module_to_patch
module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key)
module_parent = model.get_submodule(module_parent_key)
LayerPatcher._set_submodule(module_parent, module_name, wrapped_module)
else:
assert module_to_patch_key in original_modules
wrapped_module = module_to_patch
"""Apply a single LoRA wrapper patch to a module."""
# Move the LoRA layer to the same device/dtype as the orig module.
first_param = next(module_to_patch.parameters())
device = first_param.device
patch.to(device=device, dtype=dtype)
# Add the patch to the sidecar wrapper.
wrapped_module.add_patch(patch, patch_weight)
if module_to_patch_key not in original_modules:
original_modules[module_to_patch_key] = module_to_patch
module_to_patch.add_patch(patch, patch_weight)
@staticmethod
def _split_parent_key(module_key: str) -> tuple[str, str]:

View File

@ -5,7 +5,7 @@ import torch
class BaseLayerPatch(ABC):
@abstractmethod
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
"""Get the parameter residual updates that should be applied to the original parameters. Parameters omitted
from the returned dict are not updated.
"""

View File

@ -30,7 +30,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers] # pyright: ignore[reportArgumentType]
return torch.cat(layer_weights, dim=self.concat_axis)
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
# TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original bias tensor here.
# Note that we must apply the sub-layer scales here.

View File

@ -8,11 +8,11 @@ class FluxControlLoRALayer(LoRALayer):
shapes don't match.
"""
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
"""This overrides the base class behavior to skip the reshaping step."""
scale = self.scale()
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
bias = self.get_bias(orig_module.bias)
params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
bias = self.get_bias(orig_parameters.get("bias", None))
if bias is not None:
params["bias"] = bias * (weight * scale)

View File

@ -54,19 +54,19 @@ class LoRALayerBase(BaseLayerPatch):
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
scale = self.scale()
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
bias = self.get_bias(orig_module.bias)
params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
bias = self.get_bias(orig_parameters.get("bias", None))
if bias is not None:
params["bias"] = bias * (weight * scale)
# Reshape all params to match the original module's shape.
for param_name, param_weight in params.items():
orig_param = orig_module.get_parameter(param_name)
orig_param = orig_parameters[param_name]
if param_weight.shape != orig_param.shape:
params[param_name] = param_weight.reshape(orig_param.shape)

View File

@ -14,10 +14,10 @@ class SetParameterLayer(BaseLayerPatch):
self.weight = weight
self.param_name = param_name
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
# Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
# Control LoRA implementation.
diff = self.weight - orig_module.get_parameter(self.param_name)
diff = self.weight - orig_parameters[self.param_name]
return {self.param_name: diff}
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):

View File

@ -1,54 +0,0 @@
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
class BaseSidecarWrapper(torch.nn.Module):
"""A base class for sidecar wrappers.
A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a
list of patches as 'sidecar' patches. I.e. it applies the sidecar patches during forward inference without modifying
the original module.
Sidecar wrappers are typically used over regular patches when:
- The original module is quantized and so the weights can't be patched in the usual way.
- The original module is on the CPU and modifying the weights would require backing up the original weights and
doubling the CPU memory usage.
"""
def __init__(
self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None
):
super().__init__()
self._orig_module = orig_module
self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights
@property
def orig_module(self) -> torch.nn.Module:
return self._orig_module
def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
"""Add a patch to the sidecar wrapper."""
self._patches_and_weights.append((patch, patch_weight))
def _aggregate_patch_parameters(
self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> dict[str, torch.Tensor]:
"""Helper function that aggregates the parameters from all patches into a single dict."""
params: dict[str, torch.Tensor] = {}
for patch, patch_weight in patches_and_weights:
# TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
# module, this might fail or return incorrect results.
layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
for param_name, param_weight in layer_params.items():
if param_name not in params:
params[param_name] = param_weight
else:
params[param_name] += param_weight
return params
def forward(self, *args, **kwargs): # type: ignore
raise NotImplementedError()

View File

@@ -1,11 +0,0 @@
import torch
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class Conv1dSidecarWrapper(BaseSidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
return self.orig_module(input) + torch.nn.functional.conv1d(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)

View File

@@ -1,11 +0,0 @@
import torch
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class Conv2dSidecarWrapper(BaseSidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
return self.orig_module(input) + torch.nn.functional.conv2d(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)

View File

@@ -1,24 +0,0 @@
import torch
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
"""A sidecar wrapper for a FLUX RMSNorm layer.
This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
the RMSNorm scale parameters.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Given the narrow focus of this wrapper, we only support a very particular patch configuration:
assert len(self._patches_and_weights) == 1
patch, _patch_weight = self._patches_and_weights[0]
assert isinstance(patch, SetParameterLayer)
assert patch.param_name == "scale"
# Apply the patch.
# NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
# be handled.
return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)

View File

@@ -1,66 +0,0 @@
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class LinearSidecarWrapper(BaseSidecarWrapper):
def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
def _concatenated_lora_forward(
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def forward(self, input: torch.Tensor) -> torch.Tensor:
# First, apply the original linear layer.
# NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
# change the linear layer's in_features.
orig_input = input
input = orig_input[..., : self.orig_module.in_features]
output = self.orig_module(input)
# Then, apply layers for which we have optimized implementations.
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
for patch, patch_weight in self._patches_and_weights:
if isinstance(patch, FluxControlLoRALayer):
# Note that we use the original input here, not the sliced input.
output += self._lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += self._lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += self._concatenated_lora_forward(input, patch, patch_weight)
else:
unprocessed_patches_and_weights.append((patch, patch_weight))
# Finally, apply any remaining patches.
if len(unprocessed_patches_and_weights) > 0:
aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights)
output += torch.nn.functional.linear(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
return output

View File

@@ -1,20 +0,0 @@
import torch
from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
if isinstance(orig_module, torch.nn.Linear):
return LinearSidecarWrapper(orig_module)
elif isinstance(orig_module, torch.nn.Conv1d):
return Conv1dSidecarWrapper(orig_module)
elif isinstance(orig_module, torch.nn.Conv2d):
return Conv2dSidecarWrapper(orig_module)
elif isinstance(orig_module, RMSNorm):
return FluxRMSNormSidecarWrapper(orig_module)
else:
raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")

View File

@@ -48,11 +48,13 @@ GGML_TENSOR_OP_TABLE = {
# Ops to run on the quantized tensor.
torch.ops.aten.detach.default: apply_to_quantized_tensor, # pyright: ignore
torch.ops.aten._to_copy.default: apply_to_quantized_tensor, # pyright: ignore
torch.ops.aten.clone.default: apply_to_quantized_tensor, # pyright: ignore
# Ops to run on dequantized tensors.
torch.ops.aten.t.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore
}
if torch.backends.mps.is_available():
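For context, a hedged sketch of how an op table like the one above can be consulted. The function name and the handler signature `(func, args, kwargs)` are assumptions for illustration, not the actual GGML tensor dispatch code:

```python
# Look an aten op up in the table and route it to the registered handler, which
# either operates on the quantized tensor directly or dequantizes first and then
# runs the op.
def dispatch_op(func, args, kwargs, op_table):
    handler = op_table.get(func)
    if handler is None:
        raise NotImplementedError(f"Op not supported for quantized tensors: {func}")
    return handler(func, args, kwargs)
```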

View File

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
from diffusers import UNet2DConditionModel
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
@@ -31,12 +31,16 @@ class LoRAExt(ExtensionBase):
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, ModelPatchRaw)
LayerPatcher.apply_model_patch(
LayerPatcher.apply_smart_model_patch(
model=unet,
prefix="lora_unet_",
patch=lora_model,
patch_weight=self._weight,
original_weights=original_weights,
original_modules={},
dtype=unet.dtype,
force_direct_patching=True,
force_sidecar_patching=False,
)
del lora_model

View File

@@ -1,18 +1,27 @@
import itertools
import pytest
import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
apply_custom_layers_to_model,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
@pytest.fixture
def model():
model = DummyModule()
apply_custom_layers_to_model(model)
return model
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
linear1_numel = 10 * 32 + 32
linear2_numel = 32 * 64 + 64
@@ -22,8 +31,7 @@ def test_cached_model_total_bytes(device: str):
@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str):
model = DummyModule()
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
assert cached_model.cur_vram_bytes() == 0
@@ -37,8 +45,7 @@ def test_cached_model_cur_vram_bytes(device: str):
@parameterize_mps_and_cuda
def test_cached_model_partial_load(device: str):
model = DummyModule()
def test_cached_model_partial_load(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
@@ -58,14 +65,13 @@ def test_cached_model_partial_load(device: str):
if p.device.type == device and n != "buffer2"
)
# Check that the model's modules have been patched with CustomLinear layers.
assert type(model.linear1) is CustomLinear
assert type(model.linear2) is CustomLinear
# Check that the model's modules have device autocasting enabled.
assert model.linear1.is_device_autocasting_enabled()
assert model.linear2.is_device_autocasting_enabled()
@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str):
model = DummyModule()
def test_cached_model_partial_unload(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
@@ -87,14 +93,13 @@ def test_cached_model_partial_unload(device: str):
calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
)
# Check that the model's modules are still patched with CustomLinear layers.
assert type(model.linear1) is CustomLinear
assert type(model.linear2) is CustomLinear
# Check that the model's modules still have device autocasting enabled.
assert model.linear1.is_device_autocasting_enabled()
assert model.linear2.is_device_autocasting_enabled()
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str):
model = DummyModule()
def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
@@ -107,8 +112,8 @@ def test_cached_model_full_load_and_unload(device: str):
assert loaded_bytes == model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
assert type(model.linear1) is torch.nn.Linear
assert type(model.linear2) is torch.nn.Linear
assert not model.linear1.is_device_autocasting_enabled()
assert not model.linear2.is_device_autocasting_enabled()
# Full unload the model from VRAM.
unloaded_bytes = cached_model.full_unload_from_vram()
@@ -126,8 +131,7 @@
@parameterize_mps_and_cuda
def test_cached_model_full_load_from_partial(device: str):
model = DummyModule()
def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
@@ -140,8 +144,8 @@ def test_cached_model_full_load_from_partial(device: str):
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert type(model.linear1) is CustomLinear
assert type(model.linear2) is CustomLinear
assert model.linear1.is_device_autocasting_enabled()
assert model.linear2.is_device_autocasting_enabled()
# Full load the rest of the model into VRAM.
loaded_bytes_2 = cached_model.full_load_to_vram()
@@ -150,13 +154,12 @@
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
assert type(model.linear1) is torch.nn.Linear
assert type(model.linear2) is torch.nn.Linear
assert not model.linear1.is_device_autocasting_enabled()
assert not model.linear2.is_device_autocasting_enabled()
@parameterize_mps_and_cuda
def test_cached_model_full_unload_from_partial(device: str):
model = DummyModule()
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
@@ -184,8 +187,7 @@
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str):
model = DummyModule()
def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
@@ -209,8 +211,7 @@
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_inference(device: str):
model = DummyModule()
def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -237,8 +238,7 @@
@parameterize_mps_and_cuda
def test_cached_model_partial_load_and_inference(device: str):
model = DummyModule()
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
@@ -262,9 +262,9 @@
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
if p.device.type == device and n != "buffer2"
)
# Check that the model's modules have been patched with CustomLinear layers.
assert type(model.linear1) is CustomLinear
assert type(model.linear2) is CustomLinear
# Check that the model's modules have device autocasting enabled.
assert model.linear1.is_device_autocasting_enabled()
assert model.linear2.is_device_autocasting_enabled()
# Run inference on the GPU.
output2 = model(x.to(device))

View File

@@ -0,0 +1,530 @@
import copy
import gguf
import pytest
import torch
from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
AUTOCAST_MODULE_TYPE_MAPPING,
AUTOCAST_MODULE_TYPE_MAPPING_INVERSE,
unwrap_custom_layer,
wrap_custom_layer,
)
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
build_linear_8bit_lt_layer,
)
from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_nf4 import (
build_linear_nf4_layer,
)
from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
def build_linear_layer_with_ggml_quantized_tensor(orig_layer: torch.nn.Linear | None = None):
if orig_layer is None:
orig_layer = torch.nn.Linear(32, 64)
ggml_quantized_weight = quantize_tensor(orig_layer.weight, gguf.GGMLQuantizationType.Q8_0)
orig_layer.weight = torch.nn.Parameter(ggml_quantized_weight)
ggml_quantized_bias = quantize_tensor(orig_layer.bias, gguf.GGMLQuantizationType.Q8_0)
orig_layer.bias = torch.nn.Parameter(ggml_quantized_bias)
return orig_layer
parameterize_all_devices = pytest.mark.parametrize(
("device"),
[
pytest.param("cpu"),
pytest.param(
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
),
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
],
)
parameterize_cuda_and_mps = pytest.mark.parametrize(
("device"),
[
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
pytest.param(
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
),
],
)
LayerUnderTest = tuple[torch.nn.Module, torch.Tensor, bool]
@pytest.fixture(
params=[
"linear",
"conv1d",
"conv2d",
"group_norm",
"embedding",
"flux_rms_norm",
"linear_with_ggml_quantized_tensor",
"invoke_linear_8_bit_lt",
"invoke_linear_nf4",
]
)
def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest:
"""A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
layer_type = request.param
if layer_type == "linear":
return (torch.nn.Linear(8, 16), torch.randn(1, 8), True)
elif layer_type == "conv1d":
return (torch.nn.Conv1d(8, 16, 3), torch.randn(1, 8, 5), True)
elif layer_type == "conv2d":
return (torch.nn.Conv2d(8, 16, 3), torch.randn(1, 8, 5, 5), True)
elif layer_type == "group_norm":
return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True)
elif layer_type == "embedding":
return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True)
elif layer_type == "flux_rms_norm":
return (RMSNorm(8), torch.randn(1, 8), True)
elif layer_type == "linear_with_ggml_quantized_tensor":
return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True)
elif layer_type == "invoke_linear_8_bit_lt":
return (build_linear_8bit_lt_layer(), torch.randn(1, 32), False)
elif layer_type == "invoke_linear_nf4":
return (build_linear_nf4_layer(), torch.randn(1, 64), False)
else:
raise ValueError(f"Unsupported layer_type: {layer_type}")
def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str):
"""A helper function to move a layer to a device by roundtripping through a state dict. This most closely matches
how models are moved in the app. Some of the quantization types have broken semantics around calling .to() on the
layer directly, so this is a workaround.
We should fix this in the future.
Relevant article: https://pytorch.org/tutorials/recipes/recipes/swap_tensors.html
"""
state_dict = layer.state_dict()
state_dict = {k: v.to(device) for k, v in state_dict.items()}
layer.load_state_dict(state_dict, assign=True)
def wrap_single_custom_layer(layer: torch.nn.Module):
custom_layer_type = AUTOCAST_MODULE_TYPE_MAPPING[type(layer)]
return wrap_custom_layer(layer, custom_layer_type)
def unwrap_single_custom_layer(layer: torch.nn.Module):
orig_layer_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE[type(layer)]
return unwrap_custom_layer(layer, orig_layer_type)
def test_isinstance(layer_under_test: LayerUnderTest):
"""Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
orig_layer, _, _ = layer_under_test
orig_type = type(orig_layer)
custom_layer = wrap_single_custom_layer(orig_layer)
assert isinstance(custom_layer, orig_type)
assert type(custom_layer) is not orig_type
def test_wrap_and_unwrap(layer_under_test: LayerUnderTest):
"""Test that wrapping and unwrapping a layer behaves as expected."""
orig_layer, _, _ = layer_under_test
orig_type = type(orig_layer)
# Wrap the original layer and assert that attributes of the custom layer can be accessed.
custom_layer = wrap_single_custom_layer(orig_layer)
custom_layer.set_device_autocasting_enabled(True)
assert custom_layer._device_autocasting_enabled
# Unwrap the custom layer.
# Assert that the methods of the wrapped layer are no longer accessible.
unwrapped_layer = unwrap_single_custom_layer(custom_layer)
with pytest.raises(AttributeError):
_ = unwrapped_layer.set_device_autocasting_enabled(True)
# For now, we have chosen to allow attributes to persist. We may revisit this in the future.
assert unwrapped_layer._device_autocasting_enabled
assert type(unwrapped_layer) is orig_type
@parameterize_all_devices
def test_state_dict(device: str, layer_under_test: LayerUnderTest):
"""Test that .state_dict() behaves the same on the original layer and the wrapped layer."""
orig_layer, _, _ = layer_under_test
# Get the original layer on the test device.
orig_layer.to(device)
orig_state_dict = orig_layer.state_dict()
# Wrap the original layer.
custom_layer = copy.deepcopy(orig_layer)
custom_layer = wrap_single_custom_layer(custom_layer)
custom_state_dict = custom_layer.state_dict()
assert set(orig_state_dict.keys()) == set(custom_state_dict.keys())
for k in orig_state_dict:
assert orig_state_dict[k].shape == custom_state_dict[k].shape
assert orig_state_dict[k].dtype == custom_state_dict[k].dtype
assert orig_state_dict[k].device == custom_state_dict[k].device
assert torch.allclose(orig_state_dict[k], custom_state_dict[k])
@parameterize_all_devices
def test_load_state_dict(device: str, layer_under_test: LayerUnderTest):
"""Test that .load_state_dict() behaves the same on the original layer and the wrapped layer."""
orig_layer, _, _ = layer_under_test
orig_layer.to(device)
custom_layer = copy.deepcopy(orig_layer)
custom_layer = wrap_single_custom_layer(custom_layer)
# Do a state dict roundtrip.
orig_state_dict = orig_layer.state_dict()
custom_state_dict = custom_layer.state_dict()
orig_layer.load_state_dict(custom_state_dict, assign=True)
custom_layer.load_state_dict(orig_state_dict, assign=True)
orig_state_dict = orig_layer.state_dict()
custom_state_dict = custom_layer.state_dict()
# Assert that the state dicts are the same after the roundtrip.
assert set(orig_state_dict.keys()) == set(custom_state_dict.keys())
for k in orig_state_dict:
assert orig_state_dict[k].shape == custom_state_dict[k].shape
assert orig_state_dict[k].dtype == custom_state_dict[k].dtype
assert orig_state_dict[k].device == custom_state_dict[k].device
assert torch.allclose(orig_state_dict[k], custom_state_dict[k])
@parameterize_all_devices
def test_inference_on_device(device: str, layer_under_test: LayerUnderTest):
"""Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the
device.
"""
orig_layer, layer_input, supports_cpu_inference = layer_under_test
if device == "cpu" and not supports_cpu_inference:
pytest.skip("Layer does not support CPU inference.")
layer_to_device_via_state_dict(orig_layer, device)
custom_layer = copy.deepcopy(orig_layer)
custom_layer = wrap_single_custom_layer(custom_layer)
# Run inference with the original layer.
x = layer_input.to(device)
orig_output = orig_layer(x)
# Run inference with the wrapped layer.
custom_output = custom_layer(x)
assert torch.allclose(orig_output, custom_output)
@parameterize_cuda_and_mps
def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: LayerUnderTest):
"""Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the
device.
"""
orig_layer, layer_input, supports_cpu_inference = layer_under_test
if device == "cpu" and not supports_cpu_inference:
pytest.skip("Layer does not support CPU inference.")
# Make sure the original layer is on the device.
layer_to_device_via_state_dict(orig_layer, device)
x = layer_input.to(device)
# Run inference with the original layer on the device.
orig_output = orig_layer(x)
# Move the original layer to the CPU.
layer_to_device_via_state_dict(orig_layer, "cpu")
# Inference should fail with an input on the device.
with pytest.raises(RuntimeError):
_ = orig_layer(x)
# Wrap the original layer.
custom_layer = copy.deepcopy(orig_layer)
custom_layer = wrap_single_custom_layer(custom_layer)
# Inference should still fail with autocasting disabled.
custom_layer.set_device_autocasting_enabled(False)
with pytest.raises(RuntimeError):
_ = custom_layer(x)
# Run inference with the wrapped layer on the device.
custom_layer.set_device_autocasting_enabled(True)
custom_output = custom_layer(x)
assert custom_output.device.type == device
assert torch.allclose(orig_output, custom_output)
PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor]
@pytest.fixture(
params=[
"single_lora",
"multiple_loras",
"concatenated_lora",
"flux_control_lora",
]
)
def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
"""A fixture that returns a tuple of (patches, input) for the patch under test."""
layer_type = request.param
torch.manual_seed(0)
# The assumed in/out features of the base linear layer.
in_features = 32
out_features = 64
rank = 4
if layer_type == "single_lora":
lora_layer = LoRALayer(
up=torch.randn(out_features, rank),
mid=None,
down=torch.randn(rank, in_features),
alpha=1.0,
bias=torch.randn(out_features),
)
input = torch.randn(1, in_features)
return ([(lora_layer, 0.7)], input)
elif layer_type == "multiple_loras":
lora_layer = LoRALayer(
up=torch.randn(out_features, rank),
mid=None,
down=torch.randn(rank, in_features),
alpha=1.0,
bias=torch.randn(out_features),
)
lora_layer_2 = LoRALayer(
up=torch.randn(out_features, rank),
mid=None,
down=torch.randn(rank, in_features),
alpha=1.0,
bias=torch.randn(out_features),
)
input = torch.randn(1, in_features)
return ([(lora_layer, 1.0), (lora_layer_2, 0.5)], input)
elif layer_type == "concatenated_lora":
sub_layer_out_features = [16, 16, 32]
# Create a ConcatenatedLoRA layer.
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
input = torch.randn(1, in_features)
return ([(concatenated_lora_layer, 0.7)], input)
elif layer_type == "flux_control_lora":
# Create a FluxControlLoRALayer.
patched_in_features = 40
lora_layer = FluxControlLoRALayer(
up=torch.randn(out_features, rank),
mid=None,
down=torch.randn(rank, patched_in_features),
alpha=1.0,
bias=torch.randn(out_features),
)
input = torch.randn(1, patched_in_features)
return ([(lora_layer, 0.7)], input)
else:
raise ValueError(f"Unsupported layer_type: {layer_type}")
@parameterize_all_devices
def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest):
patches, input = patch_under_test
# Build the base layer under test.
layer = torch.nn.Linear(32, 64)
# Move the layer and input to the device.
layer_to_device_via_state_dict(layer, device)
input = input.to(torch.device(device))
# Patch the LoRA layer into the linear layer.
layer_patched = copy.deepcopy(layer)
for patch, weight in patches:
LayerPatcher._apply_model_layer_patch(
module_to_patch=layer_patched,
module_to_patch_key="",
patch=patch,
patch_weight=weight,
original_weights=OriginalWeightsStorage(),
)
# Wrap the original layer in a custom layer and add the patch to it as a sidecar.
custom_layer = wrap_single_custom_layer(layer)
for patch, weight in patches:
patch.to(torch.device(device))
custom_layer.add_patch(patch, weight)
# Run inference with the original layer and the patched layer and assert they are equal.
output_patched = layer_patched(input)
output_custom = custom_layer(input)
assert torch.allclose(output_patched, output_custom, atol=1e-6)
@parameterize_cuda_and_mps
def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest):
"""Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and
when the layer is on the CPU and the patches are autocasted to the device.
"""
patches, input = patch_under_test
# Build the base layer under test.
layer = torch.nn.Linear(32, 64)
# Move the layer and input to the device.
layer_to_device_via_state_dict(layer, device)
input = input.to(torch.device(device))
# Wrap the original layer in a custom layer and add the patch to it.
custom_layer = wrap_single_custom_layer(layer)
for patch, weight in patches:
patch.to(torch.device(device))
custom_layer.add_patch(patch, weight)
# Run inference with the custom layer on the device.
expected_output = custom_layer(input)
# Move the custom layer to the CPU.
layer_to_device_via_state_dict(custom_layer, "cpu")
# Move the patches to the CPU.
custom_layer.clear_patches()
for patch, weight in patches:
patch.to(torch.device("cpu"))
custom_layer.add_patch(patch, weight)
# Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to
# the device.
autocast_output = custom_layer(input)
assert autocast_output.device.type == device
# Assert that the outputs with and without autocasting are the same.
assert torch.allclose(expected_output, autocast_output, atol=1e-6)
@pytest.fixture(
params=[
"linear_ggml_quantized",
"invoke_linear_8_bit_lt",
"invoke_linear_nf4",
]
)
def quantized_linear_layer_under_test(request: pytest.FixtureRequest):
in_features = 32
out_features = 64
torch.manual_seed(0)
layer_type = request.param
orig_layer = torch.nn.Linear(in_features, out_features)
if layer_type == "linear_ggml_quantized":
return orig_layer, build_linear_layer_with_ggml_quantized_tensor(orig_layer)
elif layer_type == "invoke_linear_8_bit_lt":
return orig_layer, build_linear_8bit_lt_layer(orig_layer)
elif layer_type == "invoke_linear_nf4":
return orig_layer, build_linear_nf4_layer(orig_layer)
else:
raise ValueError(f"Unsupported layer_type: {layer_type}")
@parameterize_cuda_and_mps
def test_quantized_linear_sidecar_patches(
device: str,
quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
patch_under_test: PatchUnderTest,
):
"""Test that patches can be applied to quantized linear layers and that the output is the same as when the patch is
applied to a non-quantized linear layer.
"""
patches, input = patch_under_test
linear_layer, quantized_linear_layer = quantized_linear_layer_under_test
# Move everything to the device.
layer_to_device_via_state_dict(linear_layer, device)
layer_to_device_via_state_dict(quantized_linear_layer, device)
input = input.to(torch.device(device))
# Wrap both layers in custom layers.
linear_layer_custom = wrap_single_custom_layer(linear_layer)
quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
# Apply the patches to the custom layers.
for patch, weight in patches:
patch.to(torch.device(device))
linear_layer_custom.add_patch(patch, weight)
quantized_linear_layer_custom.add_patch(patch, weight)
# Run inference with the original layer and the patched layer and assert they are equal.
output_linear_patched = linear_layer_custom(input)
output_quantized_patched = quantized_linear_layer_custom(input)
assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2)
@parameterize_cuda_and_mps
def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
device: str,
quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
patch_under_test: PatchUnderTest,
):
"""Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and
when the layer is on the CPU and the patches are autocasted to the device.
"""
patches, input = patch_under_test
_, quantized_linear_layer = quantized_linear_layer_under_test
# Move everything to the device.
layer_to_device_via_state_dict(quantized_linear_layer, device)
input = input.to(torch.device(device))
# Wrap the quantized linear layer in a custom layer and add the patch to it.
quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
for patch, weight in patches:
patch.to(torch.device(device))
quantized_linear_layer_custom.add_patch(patch, weight)
# Run inference with the custom layer on the device.
expected_output = quantized_linear_layer_custom(input)
# Move the custom layer to the CPU.
layer_to_device_via_state_dict(quantized_linear_layer_custom, "cpu")
# Move the patches to the CPU.
quantized_linear_layer_custom.clear_patches()
for patch, weight in patches:
patch.to(torch.device("cpu"))
quantized_linear_layer_custom.add_patch(patch, weight)
# Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to
# the device.
autocast_output = quantized_linear_layer_custom(input)
assert autocast_output.device.type == device
# Assert that the outputs with and without autocasting are the same.
assert torch.allclose(expected_output, autocast_output, atol=1e-6)

View File

@@ -0,0 +1,31 @@
import torch
from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
CustomFluxRMSNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
wrap_custom_layer,
)
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
def test_custom_flux_rms_norm_patch():
"""Test a SetParameterLayer patch on a CustomFluxRMSNorm layer."""
# Create a RMSNorm layer.
dim = 8
rms_norm = RMSNorm(dim)
# Create a SetParameterLayer.
new_scale = torch.randn(dim)
set_parameter_layer = SetParameterLayer("scale", new_scale)
# Wrap the RMSNorm layer in a CustomFluxRMSNorm layer.
custom_flux_rms_norm = wrap_custom_layer(rms_norm, CustomFluxRMSNorm)
custom_flux_rms_norm.add_patch(set_parameter_layer, 1.0)
# Run the CustomFluxRMSNorm layer.
input = torch.randn(1, dim)
expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
output_custom = custom_flux_rms_norm(input)
assert torch.allclose(output_custom, expected_output, atol=1e-6)

View File

@@ -0,0 +1,82 @@
import pytest
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
wrap_custom_layer,
)
if not torch.cuda.is_available():
pytest.skip("CUDA is not available", allow_module_level=True)
else:
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import (
CustomInvokeLinear8bitLt,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
def build_linear_8bit_lt_layer(orig_layer: torch.nn.Linear | None = None):
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
torch.manual_seed(1)
if orig_layer is None:
orig_layer = torch.nn.Linear(32, 64)
orig_layer_state_dict = orig_layer.state_dict()
# Prepare a quantized InvokeLinear8bitLt layer.
quantized_layer = InvokeLinear8bitLt(
input_features=orig_layer.in_features, output_features=orig_layer.out_features, has_fp16_weights=False
)
quantized_layer.load_state_dict(orig_layer_state_dict)
quantized_layer.to("cuda")
# Assert that the InvokeLinear8bitLt layer is quantized.
assert quantized_layer.weight.CB is not None
assert quantized_layer.weight.SCB is not None
assert quantized_layer.weight.CB.dtype == torch.int8
return quantized_layer
@pytest.fixture
def linear_8bit_lt_layer():
return build_linear_8bit_lt_layer()
def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
"""Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
# Run inference on the original layer.
x = torch.randn(1, 32).to("cuda")
y_quantized = linear_8bit_lt_layer(x)
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
custom_linear_8bit_lt_layer = wrap_custom_layer(linear_8bit_lt_layer, CustomInvokeLinear8bitLt)
y_custom = custom_linear_8bit_lt_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
"""Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
# Run inference on the original layer.
x = torch.randn(1, 32).to("cuda")
y_quantized = linear_8bit_lt_layer(x)
# Copy the state dict to the CPU and reload it.
state_dict = linear_8bit_lt_layer.state_dict()
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
linear_8bit_lt_layer.load_state_dict(state_dict)
# Inference of the original layer should fail.
with pytest.raises(RuntimeError):
linear_8bit_lt_layer(x)
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
custom_linear_8bit_lt_layer = wrap_custom_layer(linear_8bit_lt_layer, CustomInvokeLinear8bitLt)
custom_linear_8bit_lt_layer.set_device_autocasting_enabled(True)
y_custom = custom_linear_8bit_lt_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)

View File

@@ -0,0 +1,92 @@
import pytest
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
wrap_custom_layer,
)
if not torch.cuda.is_available():
pytest.skip("CUDA is not available", allow_module_level=True)
else:
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import (
CustomInvokeLinearNF4,
)
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
def build_linear_nf4_layer(orig_layer: torch.nn.Linear | None = None):
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
torch.manual_seed(1)
if orig_layer is None:
orig_layer = torch.nn.Linear(64, 16)
orig_layer_state_dict = orig_layer.state_dict()
# Prepare a quantized InvokeLinearNF4 layer.
quantized_layer = InvokeLinearNF4(input_features=orig_layer.in_features, output_features=orig_layer.out_features)
quantized_layer.load_state_dict(orig_layer_state_dict)
quantized_layer.to("cuda")
# Assert that the InvokeLinearNF4 layer is quantized.
assert quantized_layer.weight.bnb_quantized
return quantized_layer
@pytest.fixture
def linear_nf4_layer():
return build_linear_nf4_layer()
def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
"""Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
# Run inference on the original layer.
x = torch.randn(1, 64).to("cuda")
y_quantized = linear_nf4_layer(x)
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4)
custom_linear_nf4_layer.set_device_autocasting_enabled(True)
y_custom = custom_linear_nf4_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the
# input dimension, and this has caused issues in the past.
@pytest.mark.parametrize("input_dim_0", [1, 2])
def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int):
"""Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
# Run inference on the original layer.
x = torch.randn(input_dim_0, 64).to(device="cuda")
y_quantized = linear_nf4_layer(x)
# Copy the state dict to the CPU and reload it.
state_dict = linear_nf4_layer.state_dict()
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
linear_nf4_layer.load_state_dict(state_dict)
# Inference of the original layer should fail.
with pytest.raises(RuntimeError):
linear_nf4_layer(x)
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4)
custom_linear_nf4_layer.set_device_autocasting_enabled(True)
y_custom = custom_linear_nf4_layer(x)
# Assert that the state dict (and the tensors that it references) are still on the CPU.
assert all(v.device == torch.device("cpu") for v in state_dict.values())
# Assert that the weight, bias, and quant_state are all on the CPU.
assert custom_linear_nf4_layer.weight.device == torch.device("cpu")
assert custom_linear_nf4_layer.bias.device == torch.device("cpu")
assert custom_linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu")
assert custom_linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu")
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)

View File

@@ -1,144 +0,0 @@
import pytest
import torch
if not torch.cuda.is_available():
pytest.skip("CUDA is not available", allow_module_level=True)
else:
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
CustomInvokeLinear8bitLt,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
CustomInvokeLinearNF4,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
@pytest.fixture
def linear_8bit_lt_layer():
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
torch.manual_seed(1)
orig_layer = torch.nn.Linear(32, 64)
orig_layer_state_dict = orig_layer.state_dict()
# Prepare a quantized InvokeLinear8bitLt layer.
quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
quantized_layer.load_state_dict(orig_layer_state_dict)
quantized_layer.to("cuda")
# Assert that the InvokeLinear8bitLt layer is quantized.
assert quantized_layer.weight.CB is not None
assert quantized_layer.weight.SCB is not None
assert quantized_layer.weight.CB.dtype == torch.int8
return quantized_layer
def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
"""Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
# Run inference on the original layer.
x = torch.randn(1, 32).to("cuda")
y_quantized = linear_8bit_lt_layer(x)
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
y_custom = linear_8bit_lt_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
"""Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
# Run inference on the original layer.
x = torch.randn(1, 32).to("cuda")
y_quantized = linear_8bit_lt_layer(x)
# Copy the state dict to the CPU and reload it.
state_dict = linear_8bit_lt_layer.state_dict()
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
linear_8bit_lt_layer.load_state_dict(state_dict)
# Inference of the original layer should fail.
with pytest.raises(RuntimeError):
linear_8bit_lt_layer(x)
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
y_custom = linear_8bit_lt_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
@pytest.fixture
def linear_nf4_layer():
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
torch.manual_seed(1)
orig_layer = torch.nn.Linear(64, 16)
orig_layer_state_dict = orig_layer.state_dict()
# Prepare a quantized InvokeLinearNF4 layer.
quantized_layer = InvokeLinearNF4(input_features=64, output_features=16)
quantized_layer.load_state_dict(orig_layer_state_dict)
quantized_layer.to("cuda")
# Assert that the InvokeLinearNF4 layer is quantized.
assert quantized_layer.weight.bnb_quantized
return quantized_layer
def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
"""Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
# Run inference on the original layer.
x = torch.randn(1, 64).to("cuda")
y_quantized = linear_nf4_layer(x)
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
linear_nf4_layer.__class__ = CustomInvokeLinearNF4
y_custom = linear_nf4_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the
# input dimension, and this has caused issues in the past.
@pytest.mark.parametrize("input_dim_0", [1, 2])
def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int):
"""Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
# Run inference on the original layer.
x = torch.randn(input_dim_0, 64).to(device="cuda")
y_quantized = linear_nf4_layer(x)
# Copy the state dict to the CPU and reload it.
state_dict = linear_nf4_layer.state_dict()
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
linear_nf4_layer.load_state_dict(state_dict)
# Inference of the original layer should fail.
with pytest.raises(RuntimeError):
linear_nf4_layer(x)
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
linear_nf4_layer.__class__ = CustomInvokeLinearNF4
y_custom = linear_nf4_layer(x)
# Assert that the state dict (and the tensors that it references) are still on the CPU.
assert all(v.device == torch.device("cpu") for v in state_dict.values())
# Assert that the weight, bias, and quant_state are all on the CPU.
assert linear_nf4_layer.weight.device == torch.device("cpu")
assert linear_nf4_layer.bias.device == torch.device("cpu")
assert linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu")
assert linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu")
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)

View File

@@ -72,7 +72,7 @@ def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.n
assert expected.device.type == "cpu"
# Apply the custom layers to the model.
apply_custom_layers_to_model(model)
apply_custom_layers_to_model(model, device_autocasting_enabled=True)
# Run the model on the device.
autocast_result = model(x.to(device))
@@ -122,7 +122,7 @@ def test_torch_module_autocast_bnb_llm_int8_linear_layer():
# Move the model back to the CPU and add the custom layers to the model.
model.to("cpu")
apply_custom_layers_to_model(model)
apply_custom_layers_to_model(model, device_autocasting_enabled=True)
# Run inference with weights being streamed to the GPU.
autocast_result = model(x.to("cuda"))

View File

@@ -18,7 +18,7 @@ def test_flux_control_lora_layer_get_parameters():
orig_module = torch.nn.Linear(small_in_features, out_features)
# Test that get_parameters() behaves as expected in spite of the difference in in_features shapes.
params = layer.get_parameters(orig_module, weight=1.0)
params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
assert "weight" in params
assert params["weight"].shape == (out_features, big_in_features)
assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha)

View File

@@ -107,7 +107,7 @@ def test_lora_layer_get_parameters():
# Create mock original module
orig_module = torch.nn.Linear(in_features, out_features)
params = layer.get_parameters(orig_module, weight=1.0)
params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
assert "weight" in params
assert params["weight"].shape == orig_module.weight.shape
assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha)

View File

@@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters():
target_weight = torch.randn(8, 4)
layer = SetParameterLayer(param_name="weight", weight=target_weight)
params = layer.get_parameters(orig_module, weight=1.0)
params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
assert len(params) == 1
new_weight = orig_module.weight + params["weight"]
assert torch.allclose(new_weight, target_weight)

View File

@@ -1,23 +0,0 @@
import torch
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
def test_flux_rms_norm_sidecar_wrapper():
# Create a RMSNorm layer.
dim = 10
rms_norm = torch.nn.RMSNorm(dim)
# Create a SetParameterLayer.
new_scale = torch.randn(dim)
set_parameter_layer = SetParameterLayer("scale", new_scale)
# Create a FluxRMSNormSidecarWrapper.
rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)])
# Run the FluxRMSNormSidecarWrapper.
input = torch.randn(1, dim)
expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
output_wrapped = rms_norm_wrapped(input)
assert torch.allclose(output_wrapped, expected_output, atol=1e-6)

View File

@@ -1,182 +0,0 @@
import copy
import torch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
@torch.no_grad()
def test_linear_sidecar_wrapper_lora():
# Create a linear layer.
in_features = 10
out_features = 20
linear = torch.nn.Linear(in_features, out_features)
# Create a LoRA layer.
rank = 4
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
# Patch the LoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
# Create a LinearSidecarWrapper.
lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])
# Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
@torch.no_grad()
def test_linear_sidecar_wrapper_multiple_loras():
# Create a linear layer.
in_features = 10
out_features = 20
linear = torch.nn.Linear(in_features, out_features)
# Create two LoRA layers.
rank = 4
lora_layer = LoRALayer(
up=torch.randn(out_features, rank),
mid=None,
down=torch.randn(rank, in_features),
alpha=1.0,
bias=torch.randn(out_features),
)
lora_layer_2 = LoRALayer(
up=torch.randn(out_features, rank),
mid=None,
down=torch.randn(rank, in_features),
alpha=1.0,
bias=torch.randn(out_features),
)
# We use different weights for the two LoRA layers to ensure this is working.
lora_weight = 1.0
lora_weight_2 = 0.5
# Patch the LoRA layers into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * (lora_layer.scale() * lora_weight)
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * (lora_layer.scale() * lora_weight)
linear_patched.weight.data += lora_layer_2.get_weight(linear_patched.weight) * (
lora_layer_2.scale() * lora_weight_2
)
linear_patched.bias.data += lora_layer_2.get_bias(linear_patched.bias) * (lora_layer_2.scale() * lora_weight_2)
# Create a LinearSidecarWrapper.
lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, lora_weight), (lora_layer_2, lora_weight_2)])
# Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
@torch.no_grad()
def test_linear_sidecar_wrapper_concatenated_lora():
# Create a linear layer.
in_features = 5
sub_layer_out_features = [5, 10, 15]
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
# Create a ConcatenatedLoRA layer.
rank = 4
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
    # Patch the ConcatenatedLoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += (
        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
    )
    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)])

    # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_linear_sidecar_wrapper_full_layer():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a FullLayer.
    full_layer = FullLayer(weight=torch.randn(out_features, in_features), bias=torch.randn(out_features))

    # Patch the FullLayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += full_layer.get_weight(linear_patched.weight)
    linear_patched.bias.data += full_layer.get_bias(linear_patched.bias)

    # Create a LinearSidecarWrapper.
    full_wrapped = LinearSidecarWrapper(linear, [(full_layer, 1.0)])

    # Run the FullLayer-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = full_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_linear_sidecar_wrapper_flux_control_lora_layer():
    # Create a linear layer.
    orig_in_features = 10
    out_features = 40
    linear = torch.nn.Linear(orig_in_features, out_features)

    # Create a FluxControlLoRALayer.
    patched_in_features = 20
    rank = 4
    lora_layer = FluxControlLoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, patched_in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )

    # Patch the FluxControlLoRALayer into the linear layer.
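    # Note: a FLUX control LoRA increases the layer's input dimension, so the base weight/bias are zero-padded to the
    # patched shape before the residual is added (the sidecar wrapper is expected to perform the same expansion
    # internally).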
    linear_patched = copy.deepcopy(linear)
    # Expand the existing weight.
    expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features]))
    linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad)
    # Expand the existing bias.
    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
    # Add the residuals.
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, patched_in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
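
All of the wrapper tests above pin down the same contract: running the unmodified base layer through a sidecar wrapper must produce the same output as baking the patch residuals directly into the layer's parameters. As a rough mental model, here is a minimal sketch (hypothetical `MinimalSidecarLinear` name, assuming each patch exposes `get_weight()`, `get_bias()`, and `scale()` like the LoRA layers above; the real `LinearSidecarWrapper` additionally handles device/dtype casting, quantized base layers, and the input expansion required by `FluxControlLoRALayer`):

```python
import torch


class MinimalSidecarLinear(torch.nn.Module):
    """Sketch only: run the original layer, then add each patch's residual contribution."""

    def __init__(self, orig: torch.nn.Linear, patches_and_weights):
        super().__init__()
        self.orig = orig
        self.patches_and_weights = patches_and_weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.orig(x)
        for patch, patch_weight in self.patches_and_weights:
            scale = patch.scale() * patch_weight
            # get_weight()/get_bias() return the residuals that direct patching would add to the base parameters.
            delta_w = patch.get_weight(self.orig.weight) * scale
            delta_b = patch.get_bias(self.orig.bias)
            if delta_b is not None:
                delta_b = delta_b * scale
            # Identity being relied on: x @ (W + dW).T + (b + db) == (x @ W.T + b) + (x @ dW.T + db)
            y = y + torch.nn.functional.linear(x, delta_w, delta_b)
        return y
```

An actual implementation would likely avoid materializing `delta_w` for low-rank patches, but the tests above only assert the numerical equivalence, not how it is computed.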

View File

@@ -0,0 +1,313 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw


class DummyModuleWithOneLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
        super().__init__()
        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_1(x)


class DummyModuleWithTwoLayers(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
        super().__init__()
        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
        self.linear_layer_2 = torch.nn.Linear(out_features, out_features, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_2(self.linear_layer_1(x))


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@pytest.mark.parametrize("num_loras", [1, 2])
@pytest.mark.parametrize(
    ["force_sidecar_patching", "force_direct_patching"], [(True, False), (False, True), (False, False)]
)
@torch.no_grad()
def test_apply_smart_model_patches(
    device: str, num_loras: int, force_sidecar_patching: bool, force_direct_patching: bool
):
"""Test the basic behavior of ModelPatcher.apply_smart_model_patches(...). Check that unpatching works correctly."""
    dtype = torch.float16
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
    apply_custom_layers_to_model(model)

    # Initialize num_loras LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_loras):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
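    # With all-ones up/down matrices, each LoRA's weight residual is (up @ down) == lora_rank * ones, so the expected
    # directly-patched weight can be written in closed form: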
    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)

    # Run inference before patching the model.
    input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
    output_before_patch = model(input)
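    # The "smart" path is expected to patch CUDA-resident layers directly and to fall back to sidecar patches for
    # CPU-resident layers, unless one of the force_* flags overrides that choice.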
    expect_sidecar_wrappers = device == "cpu"
    if force_sidecar_patching:
        expect_sidecar_wrappers = True
    elif force_direct_patching:
        expect_sidecar_wrappers = False

    # Patch the model and run inference during the patch.
    with LayerPatcher.apply_smart_model_patches(
        model=model,
        patches=lora_models,
        prefix="",
        dtype=dtype,
        force_direct_patching=force_direct_patching,
        force_sidecar_patching=force_sidecar_patching,
    ):
        if expect_sidecar_wrappers:
            # There should be sidecar patches in the model.
            assert model.linear_layer_1.get_num_patches() == num_loras
        else:
            # There should be no sidecar patches in the model.
            assert model.linear_layer_1.get_num_patches() == 0
            torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)

        # After patching, the patched model should still be on its original device.
        assert model.linear_layer_1.weight.data.device.type == device

        # After patching, all LoRA layer weights should have been moved back to the cpu.
        for lora, _ in lora_models:
            assert lora.layers["linear_layer_1"].up.device.type == "cpu"
            assert lora.layers["linear_layer_1"].down.device.type == "cpu"

        output_during_patch = model(input)

    # Run inference after unpatching.
    output_after_patch = model(input)

    # Check that the output before patching is different from the output during patching.
    assert not torch.allclose(output_before_patch, output_during_patch)

    # Check that the output before patching is the same as the output after patching.
    assert torch.allclose(output_before_patch, output_after_patch)


@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
@torch.no_grad()
def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int):
"""Test the behavior of ModelPatcher.apply_smart_lora_patches(...) when it is applied to a
CachedModelWithPartialLoad that is partially loaded into VRAM.
"""
    if not torch.cuda.is_available():
        pytest.skip("requires CUDA device")

    # Initialize the model on the CPU.
    dtype = torch.float16
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
    apply_custom_layers_to_model(model)

    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == 0

    # Partially load the model into VRAM.
    target_vram_bytes = int(model_total_bytes * 0.6)
    _ = cached_model.partial_load_to_vram(target_vram_bytes)
    assert cached_model.model.linear_layer_1.weight.device.type == "cuda"
    assert cached_model.model.linear_layer_2.weight.device.type == "cpu"

    # Initialize num_loras LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_loras):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            ),
            "linear_layer_2": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            ),
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    # Run inference before patching the model.
    input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype)
    output_before_patch = cached_model.model(input)

    # Patch the model and run inference during the patch.
    with LayerPatcher.apply_smart_model_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype):
        # Check that the second layer has sidecar patches, but the first layer does not.
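        # (linear_layer_1 was loaded into VRAM and can be patched directly; linear_layer_2 is still offloaded to the
        # CPU, so it falls back to sidecar patching.)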
        assert cached_model.model.linear_layer_1.get_num_patches() == 0
        assert cached_model.model.linear_layer_2.get_num_patches() == num_loras
        output_during_patch = cached_model.model(input)

    # Run inference after unpatching.
    output_after_patch = cached_model.model(input)

    # Check that the output before patching is different from the output during patching.
    assert not torch.allclose(output_before_patch, output_during_patch)

    # Check that the output before patching is the same as the output after patching.
    assert torch.allclose(output_before_patch, output_after_patch)


@torch.no_grad()
@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
def test_all_patching_methods_produce_same_output(num_loras: int):
"""Test that apply_lora_wrapper_patches(...) produces the same model outputs as apply_lora_patches(...)."""
    dtype = torch.float32
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
    apply_custom_layers_to_model(model)

    # Initialize num_loras LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_loras):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=lora_models, prefix="", dtype=dtype, force_direct_patching=True
    ):
        output_force_direct = model(input)

    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=lora_models, prefix="", dtype=dtype, force_sidecar_patching=True
    ):
        output_force_sidecar = model(input)

    with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_smart = model(input)

    # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
    # differences are tolerable and expected due to the difference between sidecar and direct patching.
    assert torch.allclose(output_force_direct, output_force_sidecar, atol=1e-5)
    assert torch.allclose(output_force_direct, output_smart, atol=1e-5)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_smart_model_patches_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2

    # Initialize the model on the CPU.
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
    apply_custom_layers_to_model(model)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16, force_direct_patching=True
    ):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model.linear_layer_1.weight.data.device.type == "cpu"

        # There should be no sidecar patches in the model.
        assert model.linear_layer_1.get_num_patches() == 0

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model.linear_layer_1.weight.data.device.type == "cuda"
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)


def test_apply_smart_model_patches_force_sidecar_and_direct_patching():
"""Test that ModelPatcher.apply_smart_model_patches(..., force_direct_patching=True, force_sidecar_patching=True)
raises an error.
"""
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
    apply_custom_layers_to_model(model)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)

    with pytest.raises(ValueError, match="Cannot force both direct and sidecar patching."):
        with LayerPatcher.apply_smart_model_patches(
            model=model,
            patches=[(lora, 0.5)],
            prefix="",
            dtype=torch.float16,
            force_direct_patching=True,
            force_sidecar_patching=True,
        ):
            pass

View File

@@ -1,197 +0,0 @@
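(The hunk below deletes the previous LoRA patcher test module; its coverage is subsumed by the smart-patching tests above.)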
import pytest
import torch

from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher


class DummyModule(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
        super().__init__()
        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_1(x)


@pytest.mark.parametrize(
    ["device", "num_layers"],
    [
        ("cpu", 1),
        pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
        ("cpu", 2),
        pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@torch.no_grad()
def test_apply_lora_patches(device: str, num_layers: int):
    """Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the
    correct result, and that model/LoRA tensors are moved between devices as expected.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)

    with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        for lora, _ in lora_models:
            assert lora.layers["linear_layer_1"].up.device.type == "cpu"
            assert lora.layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on its original device.
        assert model.linear_layer_1.weight.data.device.type == device

        torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)

    # After unpatching, the original model weights should have been restored on the original device.
    assert model.linear_layer_1.weight.data.device.type == device
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_patches_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2

    # Initialize the model on the CPU.
    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

    with LayerPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model.linear_layer_1.weight.data.device.type == "cpu"

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model.linear_layer_1.weight.data.device.type == "cuda"
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)


@pytest.mark.parametrize(
    ["device", "num_layers"],
    [
        ("cpu", 1),
        pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
        ("cpu", 2),
        pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
def test_apply_lora_sidecar_patches(device: str, num_layers: int):
    """Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
    dtype = torch.float16
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    # Run inference before patching the model.
    input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
    output_before_patch = model(input)

    # Patch the model and run inference during the patch.
    with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_during_patch = model(input)

    # Run inference after unpatching.
    output_after_patch = model(input)

    # Check that the output before patching is different from the output during patching.
    assert not torch.allclose(output_before_patch, output_during_patch)

    # Check that the output before patching is the same as the output after patching.
    assert torch.allclose(output_before_patch, output_after_patch)


@torch.no_grad()
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
    """Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
    dtype = torch.float32
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

    with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
        output_lora_patches = model(input)

    with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_lora_sidecar_patches = model(input)

    # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
    # differences are tolerable and expected due to the difference between sidecar vs. patching.
    assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)