Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2025-01-07 03:17:05 +08:00)
# Partial Loading PR3: Integrate 1) partial loading, 2) quantized models, 3) model patching (#7500)
## Summary

This PR is the third in a sequence of PRs working towards support for partial loading of models onto the compute device (for low-VRAM operation). This PR updates the LoRA patching code so that the following features can cooperate fully:

- Partial loading of weights onto the GPU
- Quantized layers / weights
- Model patches (e.g. LoRA)

Note that this PR does not yet enable partial loading. It adds support in the model patching code so that partial loading can be enabled in a future PR.

## Technical Design Decisions

The layer patching logic has been integrated into the custom layers (via `CustomModuleMixin`) rather than keeping it in a separate set of wrapper layers, as before. This has the following advantages:

- It makes it easier to calculate the modified weights on the fly and then reuse the normal forward() logic.
- In the future, it makes it possible to pass original parameters that have already been cast to the device down to the LoRA calculation without having to re-cast them (though the current implementation hasn't fully taken advantage of this yet).

Condensed sketches of the custom-layer structure and of the direct-vs-sidecar patching decision are included below (after this description and after the checklist, respectively).

## Known Limitations

1. I haven't fully solved device management for patch types that require the original layer value to calculate the patch. These aren't very common and are not compatible with some quantized layers, so I'm leaving this for the future if there's demand.
2. There is a small speed regression for models that have CPU bottlenecks. It seems to be caused by slightly slower method resolution on the custom layer sub-classes. The regression does not show up on larger models, like FLUX, that are almost entirely GPU-limited. I think this small regression is tolerable, but if we decide that it's not, the slowdown can easily be reclaimed by optimizing other CPU operations (e.g. if we only sent every 2nd progress image, we'd see a much more significant speedup).

## Related Issues / Discussions

- https://github.com/invoke-ai/InvokeAI/pull/7492
- https://github.com/invoke-ai/InvokeAI/pull/7494

## QA Instructions

Speed tests:

- Vanilla SD1 speed regression
  - Before: 3.156s (8.78 it/s)
  - After: 3.54s (8.35 it/s)
- Vanilla SDXL speed regression
  - Before: 6.23s (4.46 it/s)
  - After: 6.45s (4.31 it/s)
- Vanilla FLUX speed regression
  - Before: 12.02s (2.27 it/s)
  - After: 11.91s (2.29 it/s)

LoRA tests with default configuration:

- [x] SD1: A handful of LoRA variants
- [x] SDXL: A handful of LoRA variants
- [x] FLUX non-quantized: multiple LoRA variants
- [x] FLUX bnb-quantized: multiple LoRA variants
- [x] FLUX ggml-quantized: multiple LoRA variants
- [x] FLUX non-quantized: FLUX control LoRA
- [x] FLUX bnb-quantized: FLUX control LoRA
- [x] FLUX ggml-quantized: FLUX control LoRA

LoRA tests with sidecar patching forced:

- [x] SD1: A handful of LoRA variants
- [x] SDXL: A handful of LoRA variants
- [x] FLUX non-quantized: multiple LoRA variants
- [x] FLUX bnb-quantized: multiple LoRA variants
- [x] FLUX ggml-quantized: multiple LoRA variants
- [x] FLUX non-quantized: FLUX control LoRA
- [x] FLUX bnb-quantized: FLUX control LoRA
- [x] FLUX ggml-quantized: FLUX control LoRA

Other:

- [x] Smoke testing of IP-Adapter, ControlNet

All tests repeated on:

- [x] cuda
- [x] cpu (only tested SD1, because larger models are prohibitively slow)
- [x] mps (skipped FLUX tests, because my Mac doesn't have enough memory to run them in a reasonable amount of time)

## Merge Plan

No special instructions.
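To make the design above concrete, here is a condensed, illustrative sketch of the custom-layer pattern (not the full implementation; the real classes live under `invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules`, shown in the diff below). The forward-dispatch order and the `wrap_custom_layer` trick mirror the code in this PR; the simplified callable patch is a stand-in for the real `BaseLayerPatch` types.

```python
import torch
from torch import nn


class CustomModuleMixin:
    """Condensed sketch of the mixin added in this PR: tracks sidecar patches and a
    device-autocasting flag on each custom layer."""

    def __init__(self):
        self._device_autocasting_enabled = False
        self._patches_and_weights = []  # list of (patch, patch_weight) tuples

    def set_device_autocasting_enabled(self, enabled: bool) -> None:
        self._device_autocasting_enabled = enabled

    def add_patch(self, patch, patch_weight: float) -> None:
        self._patches_and_weights.append((patch, patch_weight))


class CustomLinear(nn.Linear, CustomModuleMixin):
    """Keeps isinstance(m, nn.Linear) working while adding weight streaming and
    sidecar-patch support. Instances are expected to be created via wrap_custom_layer().
    Dispatch order matches the real custom layers: sidecar patches, then device
    autocasting, then the plain nn.Linear path."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            # Sidecar path (simplified): add each patch's residual to the base output.
            # The real implementation aggregates LoRA / ConcatenatedLoRA residuals.
            output = super().forward(input)
            for patch, patch_weight in self._patches_and_weights:
                output = output + patch_weight * patch(input)  # `patch` is a stand-in callable
            return output
        elif self._device_autocasting_enabled:
            # Autocast path: stream the parameters to the input's device on the fly.
            weight = self.weight.to(input.device)
            bias = self.bias.to(input.device) if self.bias is not None else None
            return torch.nn.functional.linear(input, weight, bias)
        else:
            # Fast path: identical to a vanilla nn.Linear.
            return super().forward(input)


def wrap_custom_layer(module: nn.Module, custom_type: type) -> nn.Module:
    """Re-type an existing layer without re-running its __init__, mirroring the
    wrap_custom_layer() helper added in this PR (the original __dict__ is shared)."""
    custom = custom_type.__new__(custom_type)
    custom.__dict__ = module.__dict__
    CustomModuleMixin.__init__(custom)
    return custom


# Usage: wrap an existing layer; original isinstance checks still pass.
layer = wrap_custom_layer(nn.Linear(8, 8), CustomLinear)
layer.set_device_autocasting_enabled(True)
assert isinstance(layer, nn.Linear)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```

Because the custom class subclasses the original module type, existing `isinstance` checks and direct weight patching keep working; sidecar patches only come into play when direct patching is not possible (e.g. for quantized layers).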
## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
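For orientation before the diff: `LayerPatcher.apply_smart_model_patches()` decides, per module, whether to patch weights directly or to attach a sidecar patch. A condensed sketch of that decision, mirroring the comments in the `LayerPatcher` hunks below (the standalone function form and its name are mine):

```python
import torch


def should_use_sidecar_patching(
    module: torch.nn.Module,
    num_existing_sidecar_patches: int = 0,
    force_direct_patching: bool = False,
    force_sidecar_patching: bool = False,
) -> bool:
    """Direct patching is preferred because it is faster at runtime. Fall back to a
    sidecar patch when the caller forces it (e.g. for quantized layers), when the
    module already carries sidecar patches, or when any of its parameters still live
    on the CPU (to avoid keeping a second full copy of the original weights in RAM)."""
    if force_direct_patching and force_sidecar_patching:
        raise ValueError("Cannot force both direct and sidecar patching.")
    if force_direct_patching:
        return False
    if force_sidecar_patching:
        return True
    if num_existing_sidecar_patches > 0:
        return True
    return any(p.device.type == "cpu" for p in module.parameters())


# Example: a plain CPU-resident Linear layer falls back to sidecar patching.
print(should_use_sidecar_patching(torch.nn.Linear(4, 4)))  # True
```

In this PR the quantized-model case is handled by the caller, which passes `force_sidecar_patching=model_is_quantized` (see the FLUX denoise hunk in the diff).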
Commit: b46d7abfb0
@@ -20,8 +20,8 @@ from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LayerPatcher.apply_model_patches(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
dtype=text_encoder.dtype,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LayerPatcher.apply_model_patches(
text_encoder,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
prefix=lora_prefix,
dtype=text_encoder.dtype,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -39,8 +39,8 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
LayerPatcher.apply_model_patches(
LayerPatcher.apply_smart_model_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
dtype=unet.dtype,
cached_weights=cached_weights,
),
):
@@ -48,9 +48,9 @@ from invokeai.backend.flux.sampling_utils import (
)
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -304,36 +304,33 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
config = transformer_info.config
assert config is not None

# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
cached_weights=cached_weights,
)
)
model_is_quantized = False
elif config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LayerPatcher.apply_model_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
)
)
model_is_quantized = True
else:
raise ValueError(f"Unsupported model format: {config.format}")

# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)

# Prepare IP-Adapter extensions.
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
@@ -18,9 +18,9 @@ from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo

@@ -111,10 +111,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
LayerPatcher.apply_smart_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=clip_text_encoder.dtype,
cached_weights=cached_weights,
)
)
@@ -17,9 +17,9 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo

# The SD3 T5 Max Sequence Length set based on the default in diffusers.
@@ -150,10 +150,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
LayerPatcher.apply_smart_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=clip_text_encoder.dtype,
cached_weights=cached_weights,
)
)
@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
with (
ExitStack() as exit_stack,
unet_info as unet,
LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
LayerPatcher.apply_smart_model_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
@@ -1,9 +1,7 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
AUTOCAST_MODULE_TYPE_MAPPING,
apply_custom_layers_to_model,
remove_custom_layers_from_model,
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger
@@ -45,10 +43,10 @@ class CachedModelWithPartialLoad:

def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
"""Find all modules that support autocasting."""
return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}
return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore

def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
keys_in_modules_that_do_not_support_autocast = set()
keys_in_modules_that_do_not_support_autocast: set[str] = set()
for key in self._cpu_state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
@@ -70,6 +68,11 @@ class CachedModelWithPartialLoad:
if name in module._non_persistent_buffers_set:
module._buffers[name] = buffer.to(device, copy=True)

def _set_autocast_enabled_in_all_modules(self, enabled: bool):
"""Set autocast_enabled flag in all modules that support device autocasting."""
for module in self._modules_that_support_autocast.values():
module.set_device_autocasting_enabled(enabled)

@property
def model(self) -> torch.nn.Module:
return self._model
@@ -114,7 +117,7 @@ class CachedModelWithPartialLoad:

cur_state_dict = self._model.state_dict()

# First, process the keys *must* be loaded into VRAM.
# First, process the keys that *must* be loaded into VRAM.
for key in self._keys_in_modules_that_do_not_support_autocast:
param = cur_state_dict[key]
if param.device.type == self._compute_device.type:
@@ -157,10 +160,10 @@ class CachedModelWithPartialLoad:
self._cur_vram_bytes += vram_bytes_loaded

if fully_loaded:
remove_custom_layers_from_model(self._model)
self._set_autocast_enabled_in_all_modules(False)
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
else:
apply_custom_layers_to_model(self._model)
self._set_autocast_enabled_in_all_modules(True)

# Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
# the vram_bytes_loaded tracking.
@@ -197,5 +200,5 @@ class CachedModelWithPartialLoad:

# We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
# layers.
apply_custom_layers_to_model(self._model)
self._set_autocast_enabled_in_all_modules(True)
return vram_bytes_freed
@@ -13,6 +13,9 @@ from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
apply_custom_layers_to_model,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -143,6 +146,10 @@ class ModelCache:
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)

# Inject custom modules into the model.
if isinstance(model, torch.nn.Module):
apply_custom_layers_to_model(model)

running_on_cpu = self._execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
@@ -1,50 +0,0 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device

# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that is is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OrginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.

class CustomLinear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)

class CustomConv1d(torch.nn.Conv1d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)

class CustomConv2d(torch.nn.Conv2d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)

class CustomGroupNorm(torch.nn.GroupNorm):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

class CustomEmbedding(torch.nn.Embedding):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
return torch.nn.functional.embedding(
input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
@@ -0,0 +1,8 @@

This directory contains custom implementations of common torch.nn.Module classes that add support for:
- Streaming weights to the execution device
- Applying sidecar patches at execution time (e.g. sidecar LoRA layers)

Each custom class sub-classes the original module type that is is replacing, so the following properties are preserved:
- `isinstance(m, torch.nn.OrginalModule)` should still work.
- Patching the weights directly (e.g. for LoRA) should still work. (Of course, this is not possible for quantized layers, hence the sidecar support.)
@@ -0,0 +1,43 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)

class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)

# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": weight, "bias": bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}

aggregated_param_residuals = self._aggregate_patch_parameters(
patches_and_weights=self._patches_and_weights,
orig_params=orig_params,
device=input.device,
)

weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)

def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(input)
elif self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)
@@ -0,0 +1,43 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)

class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)

# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": weight, "bias": bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}

aggregated_param_residuals = self._aggregate_patch_parameters(
patches_and_weights=self._patches_and_weights,
orig_params=orig_params,
device=input.device,
)

weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)

def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(input)
elif self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)
@@ -0,0 +1,29 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)

class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin):
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
return torch.nn.functional.embedding(
input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
raise RuntimeError("Embedding layers do not support patches")

if self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)
@@ -0,0 +1,36 @@
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer

class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
# Currently, CustomFluxRMSNorm layers only support patching with a single SetParameterLayer.
assert len(self._patches_and_weights) == 1
patch, _patch_weight = self._patches_and_weights[0]
assert isinstance(patch, SetParameterLayer)
assert patch.param_name == "scale"

scale = cast_to_device(patch.weight, x.device)

# Apply the patch.
# NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
# be handled.
return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)

def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
scale = cast_to_device(self.scale, x.device)
return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)
@@ -0,0 +1,22 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)

class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin):
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
raise RuntimeError("GroupNorm layers do not support patches")

if self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)
@@ -2,11 +2,20 @@ import bitsandbytes as bnb
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
autocast_linear_forward_sidecar_patches,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt

class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
def forward(self, x: torch.Tensor) -> torch.Tensor:
class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)

def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
matmul_state = bnb.MatmulLtState()
matmul_state.threshold = self.state.threshold
matmul_state.has_fp16_weights = self.state.has_fp16_weights
@@ -25,3 +34,11 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
# it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
# on the wrong device.
return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)
@@ -4,11 +4,20 @@ import bitsandbytes as bnb
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
autocast_linear_forward_sidecar_patches,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4

class CustomInvokeLinearNF4(InvokeLinearNF4):
def forward(self, x: torch.Tensor) -> torch.Tensor:
class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)

def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)

# weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -43,3 +52,11 @@ class CustomInvokeLinearNF4(InvokeLinearNF4):

bias = cast_to_device(self.bias, x.device)
return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)
@@ -0,0 +1,106 @@
import copy

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer

def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x

def concatenated_lora_forward(
input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)

# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x

def autocast_linear_forward_sidecar_patches(
orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> torch.Tensor:
"""A function that runs a linear layer (quantized or non-quantized) with sidecar patches for a linear layer.
Compatible with both quantized and non-quantized Linear layers.
"""
# First, apply the original linear layer.
# NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
# change the linear layer's in_features.
orig_input = input
input = orig_input[..., : orig_module.in_features]
output = orig_module._autocast_forward(input)

# Then, apply layers for which we have optimized implementations.
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
for patch, patch_weight in patches_and_weights:
# Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
patch = copy.copy(patch)
patch.to(input.device)

if isinstance(patch, FluxControlLoRALayer):
# Note that we use the original input here, not the sliced input.
output += linear_lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += linear_lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += concatenated_lora_forward(input, patch, patch_weight)
else:
unprocessed_patches_and_weights.append((patch, patch_weight))

# Finally, apply any remaining patches.
if len(unprocessed_patches_and_weights) > 0:
# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}

aggregated_param_residuals = orig_module._aggregate_patch_parameters(
unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
)
output += torch.nn.functional.linear(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)

return output

class CustomLinear(torch.nn.Linear, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)

def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(input)
elif self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)
@@ -0,0 +1,63 @@
import copy

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch

class CustomModuleMixin:
"""A mixin class for custom modules that enables device autocasting of module parameters."""

def __init__(self):
self._device_autocasting_enabled = False
self._patches_and_weights: list[tuple[BaseLayerPatch, float]] = []

def set_device_autocasting_enabled(self, enabled: bool):
"""Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to
disable autocasting, which results in slightly faster execution speed when we know that device autocasting is
not needed.
"""
self._device_autocasting_enabled = enabled

def is_device_autocasting_enabled(self) -> bool:
"""Check if device autocasting is enabled for the module."""
return self._device_autocasting_enabled

def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
"""Add a patch to the module."""
self._patches_and_weights.append((patch, patch_weight))

def clear_patches(self):
"""Clear all patches from the module."""
self._patches_and_weights = []

def get_num_patches(self) -> int:
"""Get the number of patches in the module."""
return len(self._patches_and_weights)

def _aggregate_patch_parameters(
self,
patches_and_weights: list[tuple[BaseLayerPatch, float]],
orig_params: dict[str, torch.Tensor],
device: torch.device | None = None,
):
"""Helper function that aggregates the parameters from all patches into a single dict."""
params: dict[str, torch.Tensor] = {}

for patch, patch_weight in patches_and_weights:
if device is not None:
# Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
patch = copy.copy(patch)
patch.to(device)

# TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original
# parameters, this might fail or return incorrect results.
layer_params = patch.get_parameters(orig_params, weight=patch_weight)

for param_name, param_weight in layer_params.items():
if param_name not in params:
params[param_name] = param_weight
else:
params[param_name] += param_weight

return params
@@ -0,0 +1,30 @@
from typing import overload

import torch

@overload
def add_nullable_tensors(a: None, b: None) -> None: ...

@overload
def add_nullable_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ...

@overload
def add_nullable_tensors(a: torch.Tensor, b: None) -> torch.Tensor: ...

@overload
def add_nullable_tensors(a: None, b: torch.Tensor) -> torch.Tensor: ...

def add_nullable_tensors(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor | None:
if a is None and b is None:
return None
elif a is None:
return b
elif b is None:
return a
else:
return a + b
@@ -1,12 +1,29 @@
from typing import TypeVar

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
CustomConv1d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import (
CustomConv2d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
CustomEmbedding,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
CustomFluxRMSNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
CustomGroupNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
CustomLinear,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)

AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
torch.nn.Linear: CustomLinear,
@@ -14,14 +31,15 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
torch.nn.Conv2d: CustomConv2d,
torch.nn.GroupNorm: CustomGroupNorm,
torch.nn.Embedding: CustomEmbedding,
RMSNorm: CustomFluxRMSNorm,
}

try:
# These dependencies are not expected to be present on MacOS.
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import (
CustomInvokeLinear8bitLt,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import (
CustomInvokeLinearNF4,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
@@ -33,24 +51,55 @@ except ImportError:
pass

def apply_custom_layers_to_model(model: torch.nn.Module):
def apply_custom_layers(module: torch.nn.Module):
override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}

T = TypeVar("T", bound=torch.nn.Module)

def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T:
# HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
# existing layer instance without calling __init__() on the original layer class. We achieve this by copying
# the attributes from the original layer instance to the new instance.
custom_layer = custom_layer_type.__new__(custom_layer_type)
# Note that we share the __dict__.
# TODO(ryand): In the future, we may want to do a shallow copy of the __dict__.
custom_layer.__dict__ = module_to_wrap.__dict__

# Initialize the CustomModuleMixin fields.
CustomModuleMixin.__init__(custom_layer) # type: ignore
return custom_layer

def unwrap_custom_layer(custom_layer: torch.nn.Module, original_layer_type: type[torch.nn.Module]):
# HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
# existing layer instance without calling __init__() on the original layer class. We achieve this by copying
# the attributes from the original layer instance to the new instance.
original_layer = original_layer_type.__new__(original_layer_type)
# Note that we share the __dict__.
# TODO(ryand): In the future, we may want to do a shallow copy of the __dict__ and strip out the CustomModuleMixin
# fields.
original_layer.__dict__ = custom_layer.__dict__
return original_layer

def apply_custom_layers_to_model(module: torch.nn.Module, device_autocasting_enabled: bool = False):
for name, submodule in module.named_children():
override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(submodule), None)
if override_type is not None:
module.__class__ = override_type

# model.apply(...) calls apply_custom_layers(...) on each module in the model.
model.apply(apply_custom_layers)
custom_layer = wrap_custom_layer(submodule, override_type)
# TODO(ryand): In the future, we should manage this flag on a per-module basis.
custom_layer.set_device_autocasting_enabled(device_autocasting_enabled)
setattr(module, name, custom_layer)
else:
# Recursively apply to submodules
apply_custom_layers_to_model(submodule, device_autocasting_enabled)

def remove_custom_layers_from_model(model: torch.nn.Module):
# Invert AUTOCAST_MODULE_TYPE_MAPPING.
original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}

def remove_custom_layers(module: torch.nn.Module):
override_type = original_module_type_mapping.get(type(module), None)
def remove_custom_layers_from_model(module: torch.nn.Module):
for name, submodule in module.named_children():
override_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE.get(type(submodule), None)
if override_type is not None:
module.__class__ = override_type

# model.apply(...) calls remove_custom_layers(...) on each module in the model.
model.apply(remove_custom_layers)
setattr(module, name, unwrap_custom_layer(submodule, override_type))
else:
remove_custom_layers_from_model(submodule)
@ -7,8 +7,6 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
@ -17,58 +15,64 @@ class LayerPatcher:
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_model_patches(
|
||||
def apply_smart_model_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[ModelPatchRaw, float]],
|
||||
prefix: str,
|
||||
dtype: torch.dtype,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
force_direct_patching: bool = False,
|
||||
force_sidecar_patching: bool = False,
|
||||
):
|
||||
"""Apply one or more LoRA patches to a model within a context manager.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to patch.
|
||||
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
|
||||
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
|
||||
all at once.
|
||||
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
|
||||
cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
|
||||
CPU RAM, for efficient unpatching purposes.
|
||||
"""Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
|
||||
module.
|
||||
"""
|
||||
|
||||
# original_weights are stored for unpatching layers that are directly patched.
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
# original_modules are stored for unpatching layers that are wrapped.
|
||||
original_modules: dict[str, torch.nn.Module] = {}
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LayerPatcher.apply_model_patch(
|
||||
LayerPatcher.apply_smart_model_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
force_direct_patching=force_direct_patching,
|
||||
force_sidecar_patching=force_sidecar_patching,
|
||||
)
|
||||
del patch
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Restore directly patched layers.
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
cur_param = model.get_parameter(param_key)
|
||||
cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)
|
||||
|
||||
# Clear patches from all patched modules.
|
||||
# Note: This logic assumes no nested modules in original_modules.
|
||||
for orig_module in original_modules.values():
|
||||
orig_module.clear_patches()
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def apply_model_patch(
|
||||
def apply_smart_model_patch(
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: ModelPatchRaw,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
force_direct_patching: bool,
|
||||
force_sidecar_patching: bool,
|
||||
):
|
||||
"""Apply a single LoRA patch to a model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to patch.
|
||||
prefix (str): A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
patch (LoRAModelRaw): The LoRA model to patch in.
|
||||
patch_weight (float): The weight of the LoRA patch.
|
||||
original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
|
||||
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
|
||||
patching or a sidecar wrapper for each module.
|
||||
"""
|
||||
if patch_weight == 0:
|
||||
return
|
||||
@ -89,13 +93,50 @@ class LayerPatcher:
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
LayerPatcher._apply_model_layer_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
# Decide whether to use direct patching or a sidecar patch.
|
||||
# Direct patching is preferred, because it results in better runtime speed.
|
||||
# Reasons to use sidecar patching:
|
||||
# - The module is quantized, so the caller passed force_sidecar_patching=True.
|
||||
# - The module already has sidecar patches.
|
||||
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
|
||||
# CPU, since this would double the RAM usage)
|
||||
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
|
||||
# and that the caller will set force_sidecar_patching=True if the layer is quantized.
|
||||
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
|
||||
# forcing full patching even on the CPU?
|
||||
use_sidecar_patching = False
|
||||
if force_direct_patching and force_sidecar_patching:
|
||||
raise ValueError("Cannot force both direct and sidecar patching.")
|
||||
elif force_direct_patching:
|
||||
use_sidecar_patching = False
|
||||
elif force_sidecar_patching:
|
||||
use_sidecar_patching = True
|
||||
elif module.get_num_patches() > 0:
|
||||
use_sidecar_patching = True
|
||||
elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
|
||||
use_sidecar_patching = True
|
||||
|
||||
if use_sidecar_patching:
|
||||
LayerPatcher._apply_model_layer_wrapper_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
LayerPatcher._apply_model_layer_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
|
||||
return any(p.device.type == "cpu" for p in layer.parameters())
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@ -120,7 +161,9 @@ class LayerPatcher:
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items():
|
||||
for param_name, param_weight in patch.get_parameters(
|
||||
dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
|
||||
).items():
|
||||
param_key = module_to_patch_key + "." + param_name
|
||||
module_param = module_to_patch.get_parameter(param_name)
|
||||
|
||||
@ -143,93 +186,9 @@ class LayerPatcher:
|
||||
|
||||
patch.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_model_sidecar_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[ModelPatchRaw, float]],
|
||||
prefix: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
|
||||
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
|
||||
quantization format.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to patch.
|
||||
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
|
||||
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
|
||||
all at once.
|
||||
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
|
||||
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
|
||||
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
|
||||
different from their compute dtype.
|
||||
"""
|
||||
original_modules: dict[str, torch.nn.Module] = {}
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LayerPatcher._apply_model_sidecar_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
# Restore original modules.
|
||||
# Note: This logic assumes no nested modules in original_modules.
|
||||
for module_key, orig_module in original_modules.items():
|
||||
module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
|
||||
parent_module = model.get_submodule(module_parent_key)
|
||||
LayerPatcher._set_submodule(parent_module, module_name, orig_module)
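# Editor's sketch (illustrative, not part of this diff): typical use of the context
# manager above. `transformer` and `lora_patch` are placeholder names, not names from
# this PR; `LayerPatcher` is the class defined in this file.
import torch


def _sidecar_patching_usage_sketch(transformer: torch.nn.Module, lora_patch: "ModelPatchRaw") -> None:
    with LayerPatcher.apply_model_sidecar_patches(
        model=transformer,
        patches=[(lora_patch, 0.75)],  # Iterable of (patch, weight) tuples.
        prefix="lora_transformer_",  # Only patch keys with this prefix are applied.
        dtype=torch.bfloat16,  # Compute dtype of the sidecar layers.
    ):
        pass  # Run inference here; sidecar patches are active inside the context.
    # On exit, the original (unwrapped) modules are restored.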
|
||||
|
||||
@staticmethod
|
||||
def _apply_model_sidecar_patch(
|
||||
model: torch.nn.Module,
|
||||
patch: ModelPatchRaw,
|
||||
patch_weight: float,
|
||||
prefix: str,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA sidecar patch to a model."""
|
||||
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
|
||||
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
|
||||
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
|
||||
# without searching, but some legacy code still uses flattened keys.
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
|
||||
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = LayerPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
LayerPatcher._apply_model_layer_wrapper_patch(
|
||||
model=model,
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_model_layer_wrapper_patch(
|
||||
model: torch.nn.Module,
|
||||
module_to_patch: torch.nn.Module,
|
||||
module_to_patch_key: str,
|
||||
patch: BaseLayerPatch,
|
||||
@@ -237,25 +196,16 @@ class LayerPatcher:
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA wrapper patch to a model."""
|
||||
# Replace the original module with a BaseSidecarWrapper if it has not already been done.
|
||||
if not isinstance(module_to_patch, BaseSidecarWrapper):
|
||||
wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch)
|
||||
original_modules[module_to_patch_key] = module_to_patch
|
||||
module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key)
|
||||
module_parent = model.get_submodule(module_parent_key)
|
||||
LayerPatcher._set_submodule(module_parent, module_name, wrapped_module)
|
||||
else:
|
||||
assert module_to_patch_key in original_modules
|
||||
wrapped_module = module_to_patch
|
||||
|
||||
"""Apply a single LoRA wrapper patch to a module."""
|
||||
# Move the LoRA layer to the same device/dtype as the orig module.
|
||||
first_param = next(module_to_patch.parameters())
|
||||
device = first_param.device
|
||||
patch.to(device=device, dtype=dtype)
|
||||
|
||||
# Add the patch to the sidecar wrapper.
|
||||
wrapped_module.add_patch(patch, patch_weight)
|
||||
if module_to_patch_key not in original_modules:
|
||||
original_modules[module_to_patch_key] = module_to_patch
|
||||
|
||||
module_to_patch.add_patch(patch, patch_weight)
|
||||
|
||||
@staticmethod
|
||||
def _split_parent_key(module_key: str) -> tuple[str, str]:
|
@@ -5,7 +5,7 @@ import torch
|
||||
|
||||
class BaseLayerPatch(ABC):
|
||||
@abstractmethod
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
|
||||
"""Get the parameter residual updates that should be applied to the original parameters. Parameters omitted
|
||||
from the returned dict are not updated.
|
||||
"""
|
||||
|
@@ -30,7 +30,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
|
||||
layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers] # pyright: ignore[reportArgumentType]
|
||||
return torch.cat(layer_weights, dim=self.concat_axis)
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
|
||||
# TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
|
||||
# require this value, we will need to implement chunking of the original bias tensor here.
|
||||
# Note that we must apply the sub-layer scales here.
|
||||
|
@@ -8,11 +8,11 @@ class FluxControlLoRALayer(LoRALayer):
|
||||
shapes don't match.
|
||||
"""
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
|
||||
"""This overrides the base class behavior to skip the reshaping step."""
|
||||
scale = self.scale()
|
||||
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
|
||||
bias = self.get_bias(orig_parameters.get("bias", None))
|
||||
if bias is not None:
|
||||
params["bias"] = bias * (weight * scale)
|
||||
|
||||
|
@@ -54,19 +54,19 @@ class LoRALayerBase(BaseLayerPatch):
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
|
||||
scale = self.scale()
|
||||
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
|
||||
bias = self.get_bias(orig_parameters.get("bias", None))
|
||||
if bias is not None:
|
||||
params["bias"] = bias * (weight * scale)
|
||||
|
||||
# Reshape all params to match the original module's shape.
|
||||
for param_name, param_weight in params.items():
|
||||
orig_param = orig_module.get_parameter(param_name)
|
||||
orig_param = orig_parameters[param_name]
|
||||
if param_weight.shape != orig_param.shape:
|
||||
params[param_name] = param_weight.reshape(orig_param.shape)
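# Editor's note (illustrative, not part of this diff): for a plain Linear LoRA, the
# residual computed above is essentially delta_W = up @ down, scaled by the patch weight
# and the layer scale (assumed here to be alpha / rank), then reshaped to the original
# weight's shape if needed. A self-contained numeric example:
import torch

torch.manual_seed(0)
rank, in_features, out_features = 4, 32, 64
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
alpha, patch_weight = 1.0, 0.7
delta_w = (up @ down) * (patch_weight * (alpha / rank))
assert delta_w.shape == (out_features, in_features)  # Added to orig_parameters["weight"].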
|
||||
|
||||
|
@@ -14,10 +14,10 @@ class SetParameterLayer(BaseLayerPatch):
|
||||
self.weight = weight
|
||||
self.param_name = param_name
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
|
||||
# Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
|
||||
# Control LoRA implementation.
|
||||
diff = self.weight - orig_module.get_parameter(self.param_name)
|
||||
diff = self.weight - orig_parameters[self.param_name]
|
||||
return {self.param_name: diff}
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
|
@@ -1,54 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
|
||||
|
||||
class BaseSidecarWrapper(torch.nn.Module):
|
||||
"""A base class for sidecar wrappers.
|
||||
|
||||
A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a
|
||||
list of patches as 'sidecar' patches. I.e. it applies the sidecar patches during forward inference without modifying
|
||||
the original module.
|
||||
|
||||
Sidecar wrappers are typically used over regular patches when:
|
||||
- The original module is quantized and so the weights can't be patched in the usual way.
|
||||
- The original module is on the CPU and modifying the weights would require backing up the original weights and
|
||||
doubling the CPU memory usage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
self._orig_module = orig_module
|
||||
self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights
|
||||
|
||||
@property
|
||||
def orig_module(self) -> torch.nn.Module:
|
||||
return self._orig_module
|
||||
|
||||
def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
|
||||
"""Add a patch to the sidecar wrapper."""
|
||||
self._patches_and_weights.append((patch, patch_weight))
|
||||
|
||||
def _aggregate_patch_parameters(
|
||||
self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Helper function that aggregates the parameters from all patches into a single dict."""
|
||||
params: dict[str, torch.Tensor] = {}
|
||||
|
||||
for patch, patch_weight in patches_and_weights:
|
||||
# TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
|
||||
# module, this might fail or return incorrect results.
|
||||
layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
|
||||
|
||||
for param_name, param_weight in layer_params.items():
|
||||
if param_name not in params:
|
||||
params[param_name] = param_weight
|
||||
else:
|
||||
params[param_name] += param_weight
|
||||
|
||||
return params
|
||||
|
||||
def forward(self, *args, **kwargs): # type: ignore
|
||||
raise NotImplementedError()
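# Editor's note (illustrative, not part of this diff): the identity that sidecar patching
# relies on for a Linear layer -- patching the weights directly, (W + dW) x + (b + db),
# equals adding a sidecar residual branch, (W x + b) + (dW x + db).
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)
w, b = torch.randn(16, 8), torch.randn(16)  # Original layer parameters.
dw, db = torch.randn(16, 8) * 0.1, torch.randn(16) * 0.1  # Aggregated patch residuals.
direct = torch.nn.functional.linear(x, w + dw, b + db)
sidecar = torch.nn.functional.linear(x, w, b) + torch.nn.functional.linear(x, dw, db)
assert torch.allclose(direct, sidecar, atol=1e-5)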
|
@@ -1,11 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class Conv1dSidecarWrapper(BaseSidecarWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
|
||||
return self.orig_module(input) + torch.nn.functional.conv1d(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
@@ -1,11 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class Conv2dSidecarWrapper(BaseSidecarWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
|
||||
return self.orig_module(input) + torch.nn.functional.conv2d(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
@@ -1,24 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
|
||||
"""A sidecar wrapper for a FLUX RMSNorm layer.
|
||||
|
||||
This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
|
||||
the RMSNorm scale parameters.
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# Given the narrow focus of this wrapper, we only support a very particular patch configuration:
|
||||
assert len(self._patches_and_weights) == 1
|
||||
patch, _patch_weight = self._patches_and_weights[0]
|
||||
assert isinstance(patch, SetParameterLayer)
|
||||
assert patch.param_name == "scale"
|
||||
|
||||
# Apply the patch.
|
||||
# NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
|
||||
# be handled.
|
||||
return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)
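# Editor's note (illustrative, not part of this diff): a self-contained check that calling
# rms_norm with the replacement scale, as above, matches the textbook RMSNorm formula.
import torch

torch.manual_seed(0)
x, new_scale = torch.randn(1, 8), torch.randn(8)
expected = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * new_scale
actual = torch.nn.functional.rms_norm(x, new_scale.shape, new_scale, eps=1e-6)
assert torch.allclose(actual, expected, atol=1e-5)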
|
@@ -1,66 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class LinearSidecarWrapper(BaseSidecarWrapper):
|
||||
def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
|
||||
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
|
||||
x = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x = torch.nn.functional.linear(x, lora_layer.mid)
|
||||
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
|
||||
x *= lora_weight * lora_layer.scale()
|
||||
return x
|
||||
|
||||
def _concatenated_lora_forward(
|
||||
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
|
||||
) -> torch.Tensor:
|
||||
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
|
||||
x_chunks: list[torch.Tensor] = []
|
||||
for lora_layer in concatenated_lora_layer.lora_layers:
|
||||
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
|
||||
x_chunk *= lora_weight * lora_layer.scale()
|
||||
x_chunks.append(x_chunk)
|
||||
|
||||
# TODO(ryand): Generalize to support concat_axis != 0.
|
||||
assert concatenated_lora_layer.concat_axis == 0
|
||||
x = torch.cat(x_chunks, dim=-1)
|
||||
return x
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# First, apply the original linear layer.
|
||||
# NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
|
||||
# change the linear layer's in_features.
|
||||
orig_input = input
|
||||
input = orig_input[..., : self.orig_module.in_features]
|
||||
output = self.orig_module(input)
|
||||
|
||||
# Then, apply layers for which we have optimized implementations.
|
||||
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
|
||||
for patch, patch_weight in self._patches_and_weights:
|
||||
if isinstance(patch, FluxControlLoRALayer):
|
||||
# Note that we use the original input here, not the sliced input.
|
||||
output += self._lora_forward(orig_input, patch, patch_weight)
|
||||
elif isinstance(patch, LoRALayer):
|
||||
output += self._lora_forward(input, patch, patch_weight)
|
||||
elif isinstance(patch, ConcatenatedLoRALayer):
|
||||
output += self._concatenated_lora_forward(input, patch, patch_weight)
|
||||
else:
|
||||
unprocessed_patches_and_weights.append((patch, patch_weight))
|
||||
|
||||
# Finally, apply any remaining patches.
|
||||
if len(unprocessed_patches_and_weights) > 0:
|
||||
aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights)
|
||||
output += torch.nn.functional.linear(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
||||
|
||||
return output
|
@@ -1,20 +0,0 @@
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper


def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
    if isinstance(orig_module, torch.nn.Linear):
        return LinearSidecarWrapper(orig_module)
    elif isinstance(orig_module, torch.nn.Conv1d):
        return Conv1dSidecarWrapper(orig_module)
    elif isinstance(orig_module, torch.nn.Conv2d):
        return Conv2dSidecarWrapper(orig_module)
    elif isinstance(orig_module, RMSNorm):
        return FluxRMSNormSidecarWrapper(orig_module)
    else:
        raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
|
@@ -48,11 +48,13 @@ GGML_TENSOR_OP_TABLE = {
|
||||
# Ops to run on the quantized tensor.
|
||||
torch.ops.aten.detach.default: apply_to_quantized_tensor, # pyright: ignore
|
||||
torch.ops.aten._to_copy.default: apply_to_quantized_tensor, # pyright: ignore
|
||||
torch.ops.aten.clone.default: apply_to_quantized_tensor, # pyright: ignore
|
||||
# Ops to run on dequantized tensors.
|
||||
torch.ops.aten.t.default: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore
|
||||
}
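# Editor's sketch (illustrative, not part of this diff): the general shape of a
# "dequantize and run" handler like the one referenced in the table above. The real
# implementation lives in this module; the helper name and the get_dequantized_tensor()
# method used here are assumptions for illustration only.
import torch


def _dequantize_and_run_sketch(func, args, kwargs):
    # Replace GGML-quantized tensors among the args with dequantized torch tensors, then
    # dispatch the original aten op on the result.
    dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
    return func(*dequantized_args, **(kwargs or {}))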
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
|
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -31,12 +31,16 @@ class LoRAExt(ExtensionBase):
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
assert isinstance(lora_model, ModelPatchRaw)
|
||||
LayerPatcher.apply_model_patch(
|
||||
LayerPatcher.apply_smart_model_patch(
|
||||
model=unet,
|
||||
prefix="lora_unet_",
|
||||
patch=lora_model,
|
||||
patch_weight=self._weight,
|
||||
original_weights=original_weights,
|
||||
original_modules={},
|
||||
dtype=unet.dtype,
|
||||
force_direct_patching=True,
|
||||
force_sidecar_patching=False,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
|
@@ -1,18 +1,27 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
apply_custom_layers_to_model,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
@pytest.fixture
|
||||
def model():
|
||||
model = DummyModule()
|
||||
apply_custom_layers_to_model(model)
|
||||
return model
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
linear1_numel = 10 * 32 + 32
|
||||
linear2_numel = 32 * 64 + 64
|
||||
@@ -22,8 +31,7 @@ def test_cached_model_total_bytes(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_cur_vram_bytes(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
@@ -37,8 +45,7 @@ def test_cached_model_cur_vram_bytes(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_partial_load(device: str, model: DummyModule):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
@@ -58,14 +65,13 @@ def test_cached_model_partial_load(device: str):
|
||||
if p.device.type == device and n != "buffer2"
|
||||
)
|
||||
|
||||
# Check that the model's modules have been patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
# Check that the model's modules have device autocasting enabled.
|
||||
assert model.linear1.is_device_autocasting_enabled()
|
||||
assert model.linear2.is_device_autocasting_enabled()
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_unload(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_partial_unload(device: str, model: DummyModule):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
@@ -87,14 +93,13 @@ def test_cached_model_partial_unload(device: str):
|
||||
calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
|
||||
)
|
||||
|
||||
# Check that the model's modules are still patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
# Check that the model's modules still have device autocasting enabled.
|
||||
assert model.linear1.is_device_autocasting_enabled()
|
||||
assert model.linear2.is_device_autocasting_enabled()
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_unload(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
@@ -107,8 +112,8 @@ def test_cached_model_full_load_and_unload(device: str):
|
||||
assert loaded_bytes == model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
assert type(model.linear1) is torch.nn.Linear
|
||||
assert type(model.linear2) is torch.nn.Linear
|
||||
assert not model.linear1.is_device_autocasting_enabled()
|
||||
assert not model.linear2.is_device_autocasting_enabled()
|
||||
|
||||
# Full unload the model from VRAM.
|
||||
unloaded_bytes = cached_model.full_unload_from_vram()
|
||||
@@ -126,8 +131,7 @@ def test_cached_model_full_load_and_unload(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
@@ -140,8 +144,8 @@ def test_cached_model_full_load_from_partial(device: str):
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
assert model.linear1.is_device_autocasting_enabled()
|
||||
assert model.linear2.is_device_autocasting_enabled()
|
||||
|
||||
# Full load the rest of the model into VRAM.
|
||||
loaded_bytes_2 = cached_model.full_load_to_vram()
|
||||
@@ -150,13 +154,12 @@ def test_cached_model_full_load_from_partial(device: str):
|
||||
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
|
||||
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
assert type(model.linear1) is torch.nn.Linear
|
||||
assert type(model.linear2) is torch.nn.Linear
|
||||
assert not model.linear1.is_device_autocasting_enabled()
|
||||
assert not model.linear2.is_device_autocasting_enabled()
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_unload_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
@@ -184,8 +187,7 @@ def test_cached_model_full_unload_from_partial(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_get_cpu_state_dict(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
@@ -209,8 +211,7 @@ def test_cached_model_get_cpu_state_dict(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_inference(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
@@ -237,8 +238,7 @@ def test_cached_model_full_load_and_inference(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load_and_inference(device: str):
|
||||
model = DummyModule()
|
||||
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
@@ -262,9 +262,9 @@ def test_cached_model_partial_load_and_inference(device: str):
|
||||
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
|
||||
if p.device.type == device and n != "buffer2"
|
||||
)
|
||||
# Check that the model's modules have been patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
# Check that the model's modules have device autocasting enabled.
|
||||
assert model.linear1.is_device_autocasting_enabled()
|
||||
assert model.linear2.is_device_autocasting_enabled()
|
||||
|
||||
# Run inference on the GPU.
|
||||
output2 = model(x.to(device))
|
||||
|
@@ -0,0 +1,530 @@
|
||||
import copy
|
||||
|
||||
import gguf
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.modules.layers import RMSNorm
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
AUTOCAST_MODULE_TYPE_MAPPING,
|
||||
AUTOCAST_MODULE_TYPE_MAPPING_INVERSE,
|
||||
unwrap_custom_layer,
|
||||
wrap_custom_layer,
|
||||
)
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
|
||||
build_linear_8bit_lt_layer,
|
||||
)
|
||||
from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_nf4 import (
|
||||
build_linear_nf4_layer,
|
||||
)
|
||||
from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
|
||||
|
||||
|
||||
def build_linear_layer_with_ggml_quantized_tensor(orig_layer: torch.nn.Linear | None = None):
|
||||
if orig_layer is None:
|
||||
orig_layer = torch.nn.Linear(32, 64)
|
||||
|
||||
ggml_quantized_weight = quantize_tensor(orig_layer.weight, gguf.GGMLQuantizationType.Q8_0)
|
||||
orig_layer.weight = torch.nn.Parameter(ggml_quantized_weight)
|
||||
ggml_quantized_bias = quantize_tensor(orig_layer.bias, gguf.GGMLQuantizationType.Q8_0)
|
||||
orig_layer.bias = torch.nn.Parameter(ggml_quantized_bias)
|
||||
return orig_layer
|
||||
|
||||
|
||||
parameterize_all_devices = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param("cpu"),
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
||||
|
||||
parameterize_cuda_and_mps = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
LayerUnderTest = tuple[torch.nn.Module, torch.Tensor, bool]
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
"linear",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
"group_norm",
|
||||
"embedding",
|
||||
"flux_rms_norm",
|
||||
"linear_with_ggml_quantized_tensor",
|
||||
"invoke_linear_8_bit_lt",
|
||||
"invoke_linear_nf4",
|
||||
]
|
||||
)
|
||||
def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest:
|
||||
"""A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
|
||||
layer_type = request.param
|
||||
if layer_type == "linear":
|
||||
return (torch.nn.Linear(8, 16), torch.randn(1, 8), True)
|
||||
elif layer_type == "conv1d":
|
||||
return (torch.nn.Conv1d(8, 16, 3), torch.randn(1, 8, 5), True)
|
||||
elif layer_type == "conv2d":
|
||||
return (torch.nn.Conv2d(8, 16, 3), torch.randn(1, 8, 5, 5), True)
|
||||
elif layer_type == "group_norm":
|
||||
return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True)
|
||||
elif layer_type == "embedding":
|
||||
return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True)
|
||||
elif layer_type == "flux_rms_norm":
|
||||
return (RMSNorm(8), torch.randn(1, 8), True)
|
||||
elif layer_type == "linear_with_ggml_quantized_tensor":
|
||||
return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True)
|
||||
elif layer_type == "invoke_linear_8_bit_lt":
|
||||
return (build_linear_8bit_lt_layer(), torch.randn(1, 32), False)
|
||||
elif layer_type == "invoke_linear_nf4":
|
||||
return (build_linear_nf4_layer(), torch.randn(1, 64), False)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer_type: {layer_type}")
|
||||
|
||||
|
||||
def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str):
|
||||
"""A helper function to move a layer to a device by roundtripping through a state dict. This most closely matches
|
||||
how models are moved in the app. Some of the quantization types have broken semantics around calling .to() on the
|
||||
layer directly, so this is a workaround.
|
||||
|
||||
We should fix this in the future.
|
||||
Relevant article: https://pytorch.org/tutorials/recipes/recipes/swap_tensors.html
|
||||
"""
|
||||
state_dict = layer.state_dict()
|
||||
state_dict = {k: v.to(device) for k, v in state_dict.items()}
|
||||
layer.load_state_dict(state_dict, assign=True)
|
||||
|
||||
|
||||
def wrap_single_custom_layer(layer: torch.nn.Module):
|
||||
custom_layer_type = AUTOCAST_MODULE_TYPE_MAPPING[type(layer)]
|
||||
return wrap_custom_layer(layer, custom_layer_type)
|
||||
|
||||
|
||||
def unwrap_single_custom_layer(layer: torch.nn.Module):
|
||||
orig_layer_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE[type(layer)]
|
||||
return unwrap_custom_layer(layer, orig_layer_type)
|
||||
|
||||
|
||||
def test_isinstance(layer_under_test: LayerUnderTest):
|
||||
"""Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
|
||||
orig_layer, _, _ = layer_under_test
|
||||
orig_type = type(orig_layer)
|
||||
|
||||
custom_layer = wrap_single_custom_layer(orig_layer)
|
||||
|
||||
assert isinstance(custom_layer, orig_type)
|
||||
assert type(custom_layer) is not orig_type
|
||||
|
||||
|
||||
def test_wrap_and_unwrap(layer_under_test: LayerUnderTest):
|
||||
"""Test that wrapping and unwrapping a layer behaves as expected."""
|
||||
orig_layer, _, _ = layer_under_test
|
||||
orig_type = type(orig_layer)
|
||||
|
||||
# Wrap the original layer and assert that attributes of the custom layer can be accessed.
|
||||
custom_layer = wrap_single_custom_layer(orig_layer)
|
||||
custom_layer.set_device_autocasting_enabled(True)
|
||||
assert custom_layer._device_autocasting_enabled
|
||||
|
||||
# Unwrap the custom layer.
|
||||
# Assert that the methods of the wrapped layer are no longer accessible.
|
||||
unwrapped_layer = unwrap_single_custom_layer(custom_layer)
|
||||
with pytest.raises(AttributeError):
|
||||
_ = unwrapped_layer.set_device_autocasting_enabled(True)
|
||||
# For now, we have chosen to allow attributes to persist. We may revisit this in the future.
|
||||
assert unwrapped_layer._device_autocasting_enabled
|
||||
assert type(unwrapped_layer) is orig_type
|
||||
|
||||
|
||||
@parameterize_all_devices
|
||||
def test_state_dict(device: str, layer_under_test: LayerUnderTest):
|
||||
"""Test that .state_dict() behaves the same on the original layer and the wrapped layer."""
|
||||
orig_layer, _, _ = layer_under_test
|
||||
|
||||
# Get the original layer on the test device.
|
||||
orig_layer.to(device)
|
||||
orig_state_dict = orig_layer.state_dict()
|
||||
|
||||
# Wrap the original layer.
|
||||
custom_layer = copy.deepcopy(orig_layer)
|
||||
custom_layer = wrap_single_custom_layer(custom_layer)
|
||||
|
||||
custom_state_dict = custom_layer.state_dict()
|
||||
|
||||
assert set(orig_state_dict.keys()) == set(custom_state_dict.keys())
|
||||
for k in orig_state_dict:
|
||||
assert orig_state_dict[k].shape == custom_state_dict[k].shape
|
||||
assert orig_state_dict[k].dtype == custom_state_dict[k].dtype
|
||||
assert orig_state_dict[k].device == custom_state_dict[k].device
|
||||
assert torch.allclose(orig_state_dict[k], custom_state_dict[k])
|
||||
|
||||
|
||||
@parameterize_all_devices
|
||||
def test_load_state_dict(device: str, layer_under_test: LayerUnderTest):
|
||||
"""Test that .load_state_dict() behaves the same on the original layer and the wrapped layer."""
|
||||
orig_layer, _, _ = layer_under_test
|
||||
|
||||
orig_layer.to(device)
|
||||
|
||||
custom_layer = copy.deepcopy(orig_layer)
|
||||
custom_layer = wrap_single_custom_layer(custom_layer)
|
||||
|
||||
# Do a state dict roundtrip.
|
||||
orig_state_dict = orig_layer.state_dict()
|
||||
custom_state_dict = custom_layer.state_dict()
|
||||
|
||||
orig_layer.load_state_dict(custom_state_dict, assign=True)
|
||||
custom_layer.load_state_dict(orig_state_dict, assign=True)
|
||||
|
||||
orig_state_dict = orig_layer.state_dict()
|
||||
custom_state_dict = custom_layer.state_dict()
|
||||
|
||||
# Assert that the state dicts are the same after the roundtrip.
|
||||
assert set(orig_state_dict.keys()) == set(custom_state_dict.keys())
|
||||
for k in orig_state_dict:
|
||||
assert orig_state_dict[k].shape == custom_state_dict[k].shape
|
||||
assert orig_state_dict[k].dtype == custom_state_dict[k].dtype
|
||||
assert orig_state_dict[k].device == custom_state_dict[k].device
|
||||
assert torch.allclose(orig_state_dict[k], custom_state_dict[k])
|
||||
|
||||
|
||||
@parameterize_all_devices
|
||||
def test_inference_on_device(device: str, layer_under_test: LayerUnderTest):
|
||||
"""Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the
|
||||
device.
|
||||
"""
|
||||
orig_layer, layer_input, supports_cpu_inference = layer_under_test
|
||||
|
||||
if device == "cpu" and not supports_cpu_inference:
|
||||
pytest.skip("Layer does not support CPU inference.")
|
||||
|
||||
layer_to_device_via_state_dict(orig_layer, device)
|
||||
|
||||
custom_layer = copy.deepcopy(orig_layer)
|
||||
custom_layer = wrap_single_custom_layer(custom_layer)
|
||||
|
||||
# Run inference with the original layer.
|
||||
x = layer_input.to(device)
|
||||
orig_output = orig_layer(x)
|
||||
|
||||
# Run inference with the wrapped layer.
|
||||
custom_output = custom_layer(x)
|
||||
|
||||
assert torch.allclose(orig_output, custom_output)
|
||||
|
||||
|
||||
@parameterize_cuda_and_mps
|
||||
def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: LayerUnderTest):
|
||||
"""Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the
|
||||
device.
|
||||
"""
|
||||
orig_layer, layer_input, supports_cpu_inference = layer_under_test
|
||||
|
||||
if device == "cpu" and not supports_cpu_inference:
|
||||
pytest.skip("Layer does not support CPU inference.")
|
||||
|
||||
# Make sure the original layer is on the device.
|
||||
layer_to_device_via_state_dict(orig_layer, device)
|
||||
|
||||
x = layer_input.to(device)
|
||||
|
||||
# Run inference with the original layer on the device.
|
||||
orig_output = orig_layer(x)
|
||||
|
||||
# Move the original layer to the CPU.
|
||||
layer_to_device_via_state_dict(orig_layer, "cpu")
|
||||
|
||||
# Inference should fail with an input on the device.
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = orig_layer(x)
|
||||
|
||||
# Wrap the original layer.
|
||||
custom_layer = copy.deepcopy(orig_layer)
|
||||
custom_layer = wrap_single_custom_layer(custom_layer)
|
||||
|
||||
# Inference should still fail with autocasting disabled.
|
||||
custom_layer.set_device_autocasting_enabled(False)
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = custom_layer(x)
|
||||
|
||||
# Run inference with the wrapped layer on the device.
|
||||
custom_layer.set_device_autocasting_enabled(True)
|
||||
custom_output = custom_layer(x)
|
||||
assert custom_output.device.type == device
|
||||
|
||||
assert torch.allclose(orig_output, custom_output)
|
||||
|
||||
|
||||
PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor]
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
"single_lora",
|
||||
"multiple_loras",
|
||||
"concatenated_lora",
|
||||
"flux_control_lora",
|
||||
]
|
||||
)
|
||||
def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
|
||||
"""A fixture that returns a tuple of (patches, input) for the patch under test."""
|
||||
layer_type = request.param
|
||||
torch.manual_seed(0)
|
||||
|
||||
# The assumed in/out features of the base linear layer.
|
||||
in_features = 32
|
||||
out_features = 64
|
||||
|
||||
rank = 4
|
||||
|
||||
if layer_type == "single_lora":
|
||||
lora_layer = LoRALayer(
|
||||
up=torch.randn(out_features, rank),
|
||||
mid=None,
|
||||
down=torch.randn(rank, in_features),
|
||||
alpha=1.0,
|
||||
bias=torch.randn(out_features),
|
||||
)
|
||||
input = torch.randn(1, in_features)
|
||||
return ([(lora_layer, 0.7)], input)
|
||||
elif layer_type == "multiple_loras":
|
||||
lora_layer = LoRALayer(
|
||||
up=torch.randn(out_features, rank),
|
||||
mid=None,
|
||||
down=torch.randn(rank, in_features),
|
||||
alpha=1.0,
|
||||
bias=torch.randn(out_features),
|
||||
)
|
||||
lora_layer_2 = LoRALayer(
|
||||
up=torch.randn(out_features, rank),
|
||||
mid=None,
|
||||
down=torch.randn(rank, in_features),
|
||||
alpha=1.0,
|
||||
bias=torch.randn(out_features),
|
||||
)
|
||||
|
||||
input = torch.randn(1, in_features)
|
||||
return ([(lora_layer, 1.0), (lora_layer_2, 0.5)], input)
|
||||
elif layer_type == "concatenated_lora":
|
||||
sub_layer_out_features = [16, 16, 32]
|
||||
|
||||
# Create a ConcatenatedLoRA layer.
|
||||
sub_layers: list[LoRALayer] = []
|
||||
for out_features in sub_layer_out_features:
|
||||
down = torch.randn(rank, in_features)
|
||||
up = torch.randn(out_features, rank)
|
||||
bias = torch.randn(out_features)
|
||||
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
|
||||
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
|
||||
|
||||
input = torch.randn(1, in_features)
|
||||
return ([(concatenated_lora_layer, 0.7)], input)
|
||||
elif layer_type == "flux_control_lora":
|
||||
# Create a FluxControlLoRALayer.
|
||||
patched_in_features = 40
|
||||
lora_layer = FluxControlLoRALayer(
|
||||
up=torch.randn(out_features, rank),
|
||||
mid=None,
|
||||
down=torch.randn(rank, patched_in_features),
|
||||
alpha=1.0,
|
||||
bias=torch.randn(out_features),
|
||||
)
|
||||
|
||||
input = torch.randn(1, patched_in_features)
|
||||
return ([(lora_layer, 0.7)], input)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer_type: {layer_type}")
|
||||
|
||||
|
||||
@parameterize_all_devices
|
||||
def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest):
|
||||
patches, input = patch_under_test
|
||||
|
||||
# Build the base layer under test.
|
||||
layer = torch.nn.Linear(32, 64)
|
||||
|
||||
# Move the layer and input to the device.
|
||||
layer_to_device_via_state_dict(layer, device)
|
||||
input = input.to(torch.device(device))
|
||||
|
||||
# Patch the LoRA layer into the linear layer.
|
||||
layer_patched = copy.deepcopy(layer)
|
||||
for patch, weight in patches:
|
||||
LayerPatcher._apply_model_layer_patch(
|
||||
module_to_patch=layer_patched,
|
||||
module_to_patch_key="",
|
||||
patch=patch,
|
||||
patch_weight=weight,
|
||||
original_weights=OriginalWeightsStorage(),
|
||||
)
|
||||
|
||||
# Wrap the original layer in a custom layer and add the patch to it as a sidecar.
|
||||
custom_layer = wrap_single_custom_layer(layer)
|
||||
for patch, weight in patches:
|
||||
patch.to(torch.device(device))
|
||||
custom_layer.add_patch(patch, weight)
|
||||
|
||||
# Run inference with the original layer and the patched layer and assert they are equal.
|
||||
output_patched = layer_patched(input)
|
||||
output_custom = custom_layer(input)
|
||||
assert torch.allclose(output_patched, output_custom, atol=1e-6)
|
||||
|
||||
|
||||
@parameterize_cuda_and_mps
|
||||
def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest):
|
||||
"""Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and
|
||||
when the layer is on the CPU and the patches are autocasted to the device.
|
||||
"""
|
||||
patches, input = patch_under_test
|
||||
|
||||
# Build the base layer under test.
|
||||
layer = torch.nn.Linear(32, 64)
|
||||
|
||||
# Move the layer and input to the device.
|
||||
layer_to_device_via_state_dict(layer, device)
|
||||
input = input.to(torch.device(device))
|
||||
|
||||
# Wrap the original layer in a custom layer and add the patch to it.
|
||||
custom_layer = wrap_single_custom_layer(layer)
|
||||
for patch, weight in patches:
|
||||
patch.to(torch.device(device))
|
||||
custom_layer.add_patch(patch, weight)
|
||||
|
||||
# Run inference with the custom layer on the device.
|
||||
expected_output = custom_layer(input)
|
||||
|
||||
# Move the custom layer to the CPU.
|
||||
layer_to_device_via_state_dict(custom_layer, "cpu")
|
||||
|
||||
# Move the patches to the CPU.
|
||||
custom_layer.clear_patches()
|
||||
for patch, weight in patches:
|
||||
patch.to(torch.device("cpu"))
|
||||
custom_layer.add_patch(patch, weight)
|
||||
|
||||
# Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to
|
||||
# the device.
|
||||
autocast_output = custom_layer(input)
|
||||
assert autocast_output.device.type == device
|
||||
|
||||
# Assert that the outputs with and without autocasting are the same.
|
||||
assert torch.allclose(expected_output, autocast_output, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
"linear_ggml_quantized",
|
||||
"invoke_linear_8_bit_lt",
|
||||
"invoke_linear_nf4",
|
||||
]
|
||||
)
|
||||
def quantized_linear_layer_under_test(request: pytest.FixtureRequest):
|
||||
in_features = 32
|
||||
out_features = 64
|
||||
torch.manual_seed(0)
|
||||
layer_type = request.param
|
||||
orig_layer = torch.nn.Linear(in_features, out_features)
|
||||
if layer_type == "linear_ggml_quantized":
|
||||
return orig_layer, build_linear_layer_with_ggml_quantized_tensor(orig_layer)
|
||||
elif layer_type == "invoke_linear_8_bit_lt":
|
||||
return orig_layer, build_linear_8bit_lt_layer(orig_layer)
|
||||
elif layer_type == "invoke_linear_nf4":
|
||||
return orig_layer, build_linear_nf4_layer(orig_layer)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer_type: {layer_type}")
|
||||
|
||||
|
||||
@parameterize_cuda_and_mps
|
||||
def test_quantized_linear_sidecar_patches(
|
||||
device: str,
|
||||
quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
|
||||
patch_under_test: PatchUnderTest,
|
||||
):
|
||||
"""Test that patches can be applied to quantized linear layers and that the output is the same as when the patch is
|
||||
applied to a non-quantized linear layer.
|
||||
"""
|
||||
patches, input = patch_under_test
|
||||
|
||||
linear_layer, quantized_linear_layer = quantized_linear_layer_under_test
|
||||
|
||||
# Move everything to the device.
|
||||
layer_to_device_via_state_dict(linear_layer, device)
|
||||
layer_to_device_via_state_dict(quantized_linear_layer, device)
|
||||
input = input.to(torch.device(device))
|
||||
|
||||
# Wrap both layers in custom layers.
|
||||
linear_layer_custom = wrap_single_custom_layer(linear_layer)
|
||||
quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
|
||||
|
||||
# Apply the patches to the custom layers.
|
||||
for patch, weight in patches:
|
||||
patch.to(torch.device(device))
|
||||
linear_layer_custom.add_patch(patch, weight)
|
||||
quantized_linear_layer_custom.add_patch(patch, weight)
|
||||
|
||||
# Run inference with the original layer and the patched layer and assert they are equal.
|
||||
output_linear_patched = linear_layer_custom(input)
|
||||
output_quantized_patched = quantized_linear_layer_custom(input)
|
||||
assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2)
|
||||
|
||||
|
||||
@parameterize_cuda_and_mps
|
||||
def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
|
||||
device: str,
|
||||
quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
|
||||
patch_under_test: PatchUnderTest,
|
||||
):
|
||||
"""Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and
|
||||
when the layer is on the CPU and the patches are autocasted to the device.
|
||||
"""
|
||||
patches, input = patch_under_test
|
||||
|
||||
_, quantized_linear_layer = quantized_linear_layer_under_test
|
||||
|
||||
# Move everything to the device.
|
||||
layer_to_device_via_state_dict(quantized_linear_layer, device)
|
||||
input = input.to(torch.device(device))
|
||||
|
||||
# Wrap the quantized linear layer in a custom layer and add the patch to it.
|
||||
quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
|
||||
for patch, weight in patches:
|
||||
patch.to(torch.device(device))
|
||||
quantized_linear_layer_custom.add_patch(patch, weight)
|
||||
|
||||
# Run inference with the custom layer on the device.
|
||||
expected_output = quantized_linear_layer_custom(input)
|
||||
|
||||
# Move the custom layer to the CPU.
|
||||
layer_to_device_via_state_dict(quantized_linear_layer_custom, "cpu")
|
||||
|
||||
# Move the patches to the CPU.
|
||||
quantized_linear_layer_custom.clear_patches()
|
||||
for patch, weight in patches:
|
||||
patch.to(torch.device("cpu"))
|
||||
quantized_linear_layer_custom.add_patch(patch, weight)
|
||||
|
||||
# Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to
|
||||
# the device.
|
||||
autocast_output = quantized_linear_layer_custom(input)
|
||||
assert autocast_output.device.type == device
|
||||
|
||||
# Assert that the outputs with and without autocasting are the same.
|
||||
assert torch.allclose(expected_output, autocast_output, atol=1e-6)
|
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.modules.layers import RMSNorm
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
|
||||
CustomFluxRMSNorm,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
wrap_custom_layer,
|
||||
)
|
||||
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
|
||||
|
||||
|
||||
def test_custom_flux_rms_norm_patch():
|
||||
"""Test a SetParameterLayer patch on a CustomFluxRMSNorm layer."""
|
||||
# Create a RMSNorm layer.
|
||||
dim = 8
|
||||
rms_norm = RMSNorm(dim)
|
||||
|
||||
# Create a SetParameterLayer.
|
||||
new_scale = torch.randn(dim)
|
||||
set_parameter_layer = SetParameterLayer("scale", new_scale)
|
||||
|
||||
# Wrap the RMSNorm layer in a CustomFluxRMSNorm layer.
|
||||
custom_flux_rms_norm = wrap_custom_layer(rms_norm, CustomFluxRMSNorm)
|
||||
custom_flux_rms_norm.add_patch(set_parameter_layer, 1.0)
|
||||
|
||||
# Run the CustomFluxRMSNorm layer.
|
||||
input = torch.randn(1, dim)
|
||||
expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
|
||||
output_custom = custom_flux_rms_norm(input)
|
||||
assert torch.allclose(output_custom, expected_output, atol=1e-6)
|
@@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
wrap_custom_layer,
|
||||
)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available", allow_module_level=True)
|
||||
else:
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import (
|
||||
CustomInvokeLinear8bitLt,
|
||||
)
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
|
||||
|
||||
|
||||
def build_linear_8bit_lt_layer(orig_layer: torch.nn.Linear | None = None):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available")
|
||||
|
||||
torch.manual_seed(1)
|
||||
|
||||
if orig_layer is None:
|
||||
orig_layer = torch.nn.Linear(32, 64)
|
||||
orig_layer_state_dict = orig_layer.state_dict()
|
||||
|
||||
# Prepare a quantized InvokeLinear8bitLt layer.
|
||||
quantized_layer = InvokeLinear8bitLt(
|
||||
input_features=orig_layer.in_features, output_features=orig_layer.out_features, has_fp16_weights=False
|
||||
)
|
||||
quantized_layer.load_state_dict(orig_layer_state_dict)
|
||||
quantized_layer.to("cuda")
|
||||
|
||||
# Assert that the InvokeLinear8bitLt layer is quantized.
|
||||
assert quantized_layer.weight.CB is not None
|
||||
assert quantized_layer.weight.SCB is not None
|
||||
assert quantized_layer.weight.CB.dtype == torch.int8
|
||||
|
||||
return quantized_layer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def linear_8bit_lt_layer():
|
||||
return build_linear_8bit_lt_layer()
|
||||
|
||||
|
||||
def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
|
||||
"""Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(1, 32).to("cuda")
|
||||
y_quantized = linear_8bit_lt_layer(x)
|
||||
|
||||
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
|
||||
custom_linear_8bit_lt_layer = wrap_custom_layer(linear_8bit_lt_layer, CustomInvokeLinear8bitLt)
|
||||
y_custom = custom_linear_8bit_lt_layer(x)
|
||||
|
||||
# Assert that the quantized and custom layers produce the same output.
|
||||
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
|
||||
|
||||
|
||||
def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
|
||||
"""Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(1, 32).to("cuda")
|
||||
y_quantized = linear_8bit_lt_layer(x)
|
||||
|
||||
# Copy the state dict to the CPU and reload it.
|
||||
state_dict = linear_8bit_lt_layer.state_dict()
|
||||
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
|
||||
linear_8bit_lt_layer.load_state_dict(state_dict)
|
||||
|
||||
# Inference of the original layer should fail.
|
||||
with pytest.raises(RuntimeError):
|
||||
linear_8bit_lt_layer(x)
|
||||
|
||||
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
|
||||
custom_linear_8bit_lt_layer = wrap_custom_layer(linear_8bit_lt_layer, CustomInvokeLinear8bitLt)
|
||||
custom_linear_8bit_lt_layer.set_device_autocasting_enabled(True)
|
||||
y_custom = custom_linear_8bit_lt_layer(x)
|
||||
|
||||
# Assert that the quantized and custom layers produce the same output.
|
||||
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
|
@@ -0,0 +1,92 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
wrap_custom_layer,
|
||||
)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available", allow_module_level=True)
|
||||
else:
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import (
|
||||
CustomInvokeLinearNF4,
|
||||
)
|
||||
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
|
||||
|
||||
|
||||
def build_linear_nf4_layer(orig_layer: torch.nn.Linear | None = None):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available")
|
||||
|
||||
torch.manual_seed(1)
|
||||
|
||||
if orig_layer is None:
|
||||
orig_layer = torch.nn.Linear(64, 16)
|
||||
|
||||
orig_layer_state_dict = orig_layer.state_dict()
|
||||
|
||||
# Prepare a quantized InvokeLinearNF4 layer.
|
||||
quantized_layer = InvokeLinearNF4(input_features=orig_layer.in_features, output_features=orig_layer.out_features)
|
||||
quantized_layer.load_state_dict(orig_layer_state_dict)
|
||||
quantized_layer.to("cuda")
|
||||
|
||||
# Assert that the InvokeLinearNF4 layer is quantized.
|
||||
assert quantized_layer.weight.bnb_quantized
|
||||
|
||||
return quantized_layer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def linear_nf4_layer():
|
||||
return build_linear_nf4_layer()
|
||||
|
||||
|
||||
def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
|
||||
"""Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(1, 64).to("cuda")
|
||||
y_quantized = linear_nf4_layer(x)
|
||||
|
||||
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
|
||||
custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4)
|
||||
custom_linear_nf4_layer.set_device_autocasting_enabled(True)
|
||||
y_custom = custom_linear_nf4_layer(x)
|
||||
|
||||
# Assert that the quantized and custom layers produce the same output.
|
||||
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
|
||||
|
||||
|
||||
# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the
|
||||
# input dimension, and this has caused issues in the past.
|
||||
@pytest.mark.parametrize("input_dim_0", [1, 2])
|
||||
def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int):
|
||||
"""Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(input_dim_0, 64).to(device="cuda")
|
||||
y_quantized = linear_nf4_layer(x)
|
||||
|
||||
# Copy the state dict to the CPU and reload it.
|
||||
state_dict = linear_nf4_layer.state_dict()
|
||||
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
|
||||
linear_nf4_layer.load_state_dict(state_dict)
|
||||
|
||||
# Inference of the original layer should fail.
|
||||
with pytest.raises(RuntimeError):
|
||||
linear_nf4_layer(x)
|
||||
|
||||
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
|
||||
custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4)
|
||||
custom_linear_nf4_layer.set_device_autocasting_enabled(True)
|
||||
y_custom = custom_linear_nf4_layer(x)
|
||||
|
||||
# Assert that the state dict (and the tensors that it references) are still on the CPU.
|
||||
assert all(v.device == torch.device("cpu") for v in state_dict.values())
|
||||
|
||||
# Assert that the weight, bias, and quant_state are all on the CPU.
|
||||
assert custom_linear_nf4_layer.weight.device == torch.device("cpu")
|
||||
assert custom_linear_nf4_layer.bias.device == torch.device("cpu")
|
||||
assert custom_linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu")
|
||||
assert custom_linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu")
|
||||
|
||||
# Assert that the quantized and custom layers produce the same output.
|
||||
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
|
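The CPU-weights test above pins down the device-autocasting contract: the module's parameters (including the NF4 `quant_state` tensors) stay on the CPU, yet inference on a CUDA input succeeds and matches the fully-on-GPU result. A conceptual sketch of that behavior for a plain linear layer, assuming the custom layer casts parameters to the input's device inside `forward()` without mutating the module (illustrative only, not the InvokeAI implementation):

```python
import torch


class AutocastLinearSketch(torch.nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stream CPU-resident parameters to the input's device for this call only;
        # self.weight / self.bias themselves remain on the CPU.
        weight = self.weight.to(x.device)
        bias = self.bias.to(x.device) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
```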
@@ -1,144 +0,0 @@
import pytest
import torch

if not torch.cuda.is_available():
    pytest.skip("CUDA is not available", allow_module_level=True)
else:
    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
        CustomInvokeLinear8bitLt,
    )
    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
        CustomInvokeLinearNF4,
    )
    from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
    from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4


@pytest.fixture
def linear_8bit_lt_layer():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    torch.manual_seed(1)

    orig_layer = torch.nn.Linear(32, 64)
    orig_layer_state_dict = orig_layer.state_dict()

    # Prepare a quantized InvokeLinear8bitLt layer.
    quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
    quantized_layer.load_state_dict(orig_layer_state_dict)
    quantized_layer.to("cuda")

    # Assert that the InvokeLinear8bitLt layer is quantized.
    assert quantized_layer.weight.CB is not None
    assert quantized_layer.weight.SCB is not None
    assert quantized_layer.weight.CB.dtype == torch.int8

    return quantized_layer


def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
    """Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
    # Run inference on the original layer.
    x = torch.randn(1, 32).to("cuda")
    y_quantized = linear_8bit_lt_layer(x)

    # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
    linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
    y_custom = linear_8bit_lt_layer(x)

    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)


def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
    """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
    # Run inference on the original layer.
    x = torch.randn(1, 32).to("cuda")
    y_quantized = linear_8bit_lt_layer(x)

    # Copy the state dict to the CPU and reload it.
    state_dict = linear_8bit_lt_layer.state_dict()
    state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
    linear_8bit_lt_layer.load_state_dict(state_dict)

    # Inference of the original layer should fail.
    with pytest.raises(RuntimeError):
        linear_8bit_lt_layer(x)

    # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
    linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
    y_custom = linear_8bit_lt_layer(x)

    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)


@pytest.fixture
def linear_nf4_layer():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    torch.manual_seed(1)

    orig_layer = torch.nn.Linear(64, 16)
    orig_layer_state_dict = orig_layer.state_dict()

    # Prepare a quantized InvokeLinearNF4 layer.
    quantized_layer = InvokeLinearNF4(input_features=64, output_features=16)
    quantized_layer.load_state_dict(orig_layer_state_dict)
    quantized_layer.to("cuda")

    # Assert that the InvokeLinearNF4 layer is quantized.
    assert quantized_layer.weight.bnb_quantized

    return quantized_layer


def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
    """Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
    # Run inference on the original layer.
    x = torch.randn(1, 64).to("cuda")
    y_quantized = linear_nf4_layer(x)

    # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
    linear_nf4_layer.__class__ = CustomInvokeLinearNF4
    y_custom = linear_nf4_layer(x)

    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)


# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the
# input dimension, and this has caused issues in the past.
@pytest.mark.parametrize("input_dim_0", [1, 2])
def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int):
    """Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
    # Run inference on the original layer.
    x = torch.randn(input_dim_0, 64).to(device="cuda")
    y_quantized = linear_nf4_layer(x)

    # Copy the state dict to the CPU and reload it.
    state_dict = linear_nf4_layer.state_dict()
    state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
    linear_nf4_layer.load_state_dict(state_dict)

    # Inference of the original layer should fail.
    with pytest.raises(RuntimeError):
        linear_nf4_layer(x)

    # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
    linear_nf4_layer.__class__ = CustomInvokeLinearNF4
    y_custom = linear_nf4_layer(x)

    # Assert that the state dict (and the tensors that it references) are still on the CPU.
    assert all(v.device == torch.device("cpu") for v in state_dict.values())

    # Assert that the weight, bias, and quant_state are all on the CPU.
    assert linear_nf4_layer.weight.device == torch.device("cpu")
    assert linear_nf4_layer.bias.device == torch.device("cpu")
    assert linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu")
    assert linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu")

    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)
@@ -72,7 +72,7 @@ def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.n
     assert expected.device.type == "cpu"

     # Apply the custom layers to the model.
-    apply_custom_layers_to_model(model)
+    apply_custom_layers_to_model(model, device_autocasting_enabled=True)

     # Run the model on the device.
     autocast_result = model(x.to(device))
@@ -122,7 +122,7 @@ def test_torch_module_autocast_bnb_llm_int8_linear_layer():

     # Move the model back to the CPU and add the custom layers to the model.
     model.to("cpu")
-    apply_custom_layers_to_model(model)
+    apply_custom_layers_to_model(model, device_autocasting_enabled=True)

     # Run inference with weights being streamed to the GPU.
     autocast_result = model(x.to("cuda"))
@@ -18,7 +18,7 @@ def test_flux_control_lora_layer_get_parameters():
     orig_module = torch.nn.Linear(small_in_features, out_features)

     # Test that get_parameters() behaves as expected in spite of the difference in in_features shapes.
-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert "weight" in params
     assert params["weight"].shape == (out_features, big_in_features)
     assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha)
@@ -107,7 +107,7 @@ def test_lora_layer_get_parameters():
     # Create mock original module
     orig_module = torch.nn.Linear(in_features, out_features)

-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert "weight" in params
     assert params["weight"].shape == orig_module.weight.shape
     assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha)
@@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters():
     target_weight = torch.randn(8, 4)
     layer = SetParameterLayer(param_name="weight", weight=target_weight)

-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert len(params) == 1
     new_weight = orig_module.weight + params["weight"]
     assert torch.allclose(new_weight, target_weight)
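The three hunks above make the same call-site change: the patch layers' `get_parameters(...)` now receives a plain dict of the original module's parameters rather than the module itself. A small standalone sketch of how a caller builds that dict (the helper name here is illustrative, not part of the library):

```python
import torch


def params_for_patch(module: torch.nn.Module) -> dict[str, torch.Tensor]:
    # recurse=False mirrors the call sites above: only the module's own parameters
    # (e.g. "weight" and "bias") are passed, not those of any child modules.
    return dict(module.named_parameters(recurse=False))


linear = torch.nn.Linear(4, 8)
assert set(params_for_patch(linear).keys()) == {"weight", "bias"}
```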
@@ -1,23 +0,0 @@
import torch

from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper


def test_flux_rms_norm_sidecar_wrapper():
    # Create a RMSNorm layer.
    dim = 10
    rms_norm = torch.nn.RMSNorm(dim)

    # Create a SetParameterLayer.
    new_scale = torch.randn(dim)
    set_parameter_layer = SetParameterLayer("scale", new_scale)

    # Create a FluxRMSNormSidecarWrapper.
    rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)])

    # Run the FluxRMSNormSidecarWrapper.
    input = torch.randn(1, dim)
    expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
    output_wrapped = rms_norm_wrapped(input)
    assert torch.allclose(output_wrapped, expected_output, atol=1e-6)
@@ -1,182 +0,0 @@
import copy

import torch

from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper


@torch.no_grad()
def test_linear_sidecar_wrapper_lora():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a LoRA layer.
    rank = 4
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    bias = torch.randn(out_features)
    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

    # Patch the LoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


@torch.no_grad()
def test_linear_sidecar_wrapper_multiple_loras():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create two LoRA layers.
    rank = 4
    lora_layer = LoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )
    lora_layer_2 = LoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )
    # We use different weights for the two LoRA layers to ensure this is working.
    lora_weight = 1.0
    lora_weight_2 = 0.5

    # Patch the LoRA layers into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * (lora_layer.scale() * lora_weight)
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * (lora_layer.scale() * lora_weight)
    linear_patched.weight.data += lora_layer_2.get_weight(linear_patched.weight) * (
        lora_layer_2.scale() * lora_weight_2
    )
    linear_patched.bias.data += lora_layer_2.get_bias(linear_patched.bias) * (lora_layer_2.scale() * lora_weight_2)

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, lora_weight), (lora_layer_2, lora_weight_2)])

    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


@torch.no_grad()
def test_linear_sidecar_wrapper_concatenated_lora():
    # Create a linear layer.
    in_features = 5
    sub_layer_out_features = [5, 10, 15]
    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))

    # Create a ConcatenatedLoRA layer.
    rank = 4
    sub_layers: list[LoRALayer] = []
    for out_features in sub_layer_out_features:
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)

    # Patch the ConcatenatedLoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += (
        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
    )
    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)])

    # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_linear_sidecar_wrapper_full_layer():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a FullLayer.
    full_layer = FullLayer(weight=torch.randn(out_features, in_features), bias=torch.randn(out_features))

    # Patch the FullLayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += full_layer.get_weight(linear_patched.weight)
    linear_patched.bias.data += full_layer.get_bias(linear_patched.bias)

    # Create a LinearSidecarWrapper.
    full_wrapped = LinearSidecarWrapper(linear, [(full_layer, 1.0)])

    # Run the FullLayer-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = full_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_linear_sidecar_wrapper_flux_control_lora_layer():
    # Create a linear layer.
    orig_in_features = 10
    out_features = 40
    linear = torch.nn.Linear(orig_in_features, out_features)

    # Create a FluxControlLoRALayer.
    patched_in_features = 20
    rank = 4
    lora_layer = FluxControlLoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, patched_in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )

    # Patch the FluxControlLoRALayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    # Expand the existing weight.
    expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features]))
    linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad)
    # Expand the existing bias.
    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
    # Add the residuals.
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, patched_in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
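The deleted tests above capture the numerical contract that the new integrated patching must preserve: evaluating a sidecar-style patch at runtime is equivalent to baking it into the weights. A minimal standalone check of that equivalence for a single LoRA patch, using the standard low-rank formulation (illustrative, not the library's code path):

```python
import torch

torch.manual_seed(0)
in_features, out_features, rank, scale = 10, 20, 4, 1.0
linear = torch.nn.Linear(in_features, out_features)
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
x = torch.randn(1, in_features)

# Direct patching: fold the low-rank residual into the weight matrix.
w_patched = linear.weight + (up @ down) * scale
y_direct = torch.nn.functional.linear(x, w_patched, linear.bias)

# Sidecar-style evaluation: keep the original weight and add the residual at runtime.
y_sidecar = linear(x) + (x @ down.T @ up.T) * scale

assert torch.allclose(y_direct, y_sidecar, atol=1e-5)
```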
tests/backend/patches/test_layer_patcher.py (new file, 313 lines)
@@ -0,0 +1,313 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw


class DummyModuleWithOneLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
        super().__init__()
        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_1(x)


class DummyModuleWithTwoLayers(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
        super().__init__()
        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
        self.linear_layer_2 = torch.nn.Linear(out_features, out_features, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_2(self.linear_layer_1(x))


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@pytest.mark.parametrize("num_loras", [1, 2])
@pytest.mark.parametrize(
    ["force_sidecar_patching", "force_direct_patching"], [(True, False), (False, True), (False, False)]
)
@torch.no_grad()
def test_apply_smart_model_patches(
    device: str, num_loras: int, force_sidecar_patching: bool, force_direct_patching: bool
):
    """Test the basic behavior of LayerPatcher.apply_smart_model_patches(...). Check that unpatching works correctly."""
    dtype = torch.float16
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
    apply_custom_layers_to_model(model)

    # Initialize num_loras LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_loras):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)

    # Run inference before patching the model.
    input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
    output_before_patch = model(input)

    expect_sidecar_wrappers = device == "cpu"
    if force_sidecar_patching:
        expect_sidecar_wrappers = True
    elif force_direct_patching:
        expect_sidecar_wrappers = False

    # Patch the model and run inference during the patch.
    with LayerPatcher.apply_smart_model_patches(
        model=model,
        patches=lora_models,
        prefix="",
        dtype=dtype,
        force_direct_patching=force_direct_patching,
        force_sidecar_patching=force_sidecar_patching,
    ):
        if expect_sidecar_wrappers:
            # There should be sidecar patches in the model.
            assert model.linear_layer_1.get_num_patches() == num_loras
        else:
            # There should be no sidecar patches in the model.
            assert model.linear_layer_1.get_num_patches() == 0
            torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)

        # After patching, the patched model should still be on its original device.
        assert model.linear_layer_1.weight.data.device.type == device

        # After patching, all LoRA layer weights should have been moved back to the cpu.
        for lora, _ in lora_models:
            assert lora.layers["linear_layer_1"].up.device.type == "cpu"
            assert lora.layers["linear_layer_1"].down.device.type == "cpu"

        output_during_patch = model(input)

    # Run inference after unpatching.
    output_after_patch = model(input)

    # Check that the output before patching is different from the output during patching.
    assert not torch.allclose(output_before_patch, output_during_patch)

    # Check that the output before patching is the same as the output after patching.
    assert torch.allclose(output_before_patch, output_after_patch)


@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
@torch.no_grad()
def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int):
    """Test the behavior of LayerPatcher.apply_smart_model_patches(...) when it is applied to a
    CachedModelWithPartialLoad that is partially loaded into VRAM.
    """

    if not torch.cuda.is_available():
        pytest.skip("requires CUDA device")

    # Initialize the model on the CPU.
    dtype = torch.float16
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
    apply_custom_layers_to_model(model)
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == 0

    # Partially load the model into VRAM.
    target_vram_bytes = int(model_total_bytes * 0.6)
    _ = cached_model.partial_load_to_vram(target_vram_bytes)
    assert cached_model.model.linear_layer_1.weight.device.type == "cuda"
    assert cached_model.model.linear_layer_2.weight.device.type == "cpu"

    # Initialize num_loras LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_loras):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            ),
            "linear_layer_2": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            ),
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    # Run inference before patching the model.
    input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype)
    output_before_patch = cached_model.model(input)

    # Patch the model and run inference during the patch.
    with LayerPatcher.apply_smart_model_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype):
        # Check that the second layer has sidecar patches, but the first layer does not.
        assert cached_model.model.linear_layer_1.get_num_patches() == 0
        assert cached_model.model.linear_layer_2.get_num_patches() == num_loras

        output_during_patch = cached_model.model(input)

    # Run inference after unpatching.
    output_after_patch = cached_model.model(input)

    # Check that the output before patching is different from the output during patching.
    assert not torch.allclose(output_before_patch, output_during_patch)

    # Check that the output before patching is the same as the output after patching.
    assert torch.allclose(output_before_patch, output_after_patch)


@torch.no_grad()
@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
def test_all_patching_methods_produce_same_output(num_loras: int):
    """Test that direct patching, sidecar patching, and smart patching all produce the same model outputs."""
    dtype = torch.float32
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
    apply_custom_layers_to_model(model)

    # Initialize num_loras LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_loras):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=lora_models, prefix="", dtype=dtype, force_direct_patching=True
    ):
        output_force_direct = model(input)

    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=lora_models, prefix="", dtype=dtype, force_sidecar_patching=True
    ):
        output_force_sidecar = model(input)

    with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_smart = model(input)

    # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
    # differences are tolerable and expected due to the difference between sidecar and direct patching.
    assert torch.allclose(output_force_direct, output_force_sidecar, atol=1e-5)
    assert torch.allclose(output_force_direct, output_smart, atol=1e-5)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_smart_model_patches_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
    apply_custom_layers_to_model(model)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16, force_direct_patching=True
    ):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model.linear_layer_1.weight.data.device.type == "cpu"

        # There should be no sidecar patches in the model.
        assert model.linear_layer_1.get_num_patches() == 0

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model.linear_layer_1.weight.data.device.type == "cuda"
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)


def test_apply_smart_model_patches_force_sidecar_and_direct_patching():
    """Test that LayerPatcher.apply_smart_model_patches(..., force_direct_patching=True, force_sidecar_patching=True)
    raises an error.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
    apply_custom_layers_to_model(model)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)
    with pytest.raises(ValueError, match="Cannot force both direct and sidecar patching."):
        with LayerPatcher.apply_smart_model_patches(
            model=model,
            patches=[(lora, 0.5)],
            prefix="",
            dtype=torch.float16,
            force_direct_patching=True,
            force_sidecar_patching=True,
        ):
            pass
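A quick note on the expected value used in `test_apply_smart_model_patches` above: with all-ones `lora_up`/`lora_down` matrices and, presumably, a default alpha equal to the rank (so that the LoRA scale is 1.0), each patch adds a constant `lora_rank * lora_weight` to every weight entry, and `num_loras` stacked patches add `lora_rank * lora_weight * num_loras` in total. A standalone check of that arithmetic:

```python
import torch

lora_rank, lora_weight, num_loras = 2, 0.5, 2
in_features, out_features = 4, 8

up = torch.ones(out_features, lora_rank)
down = torch.ones(lora_rank, in_features)

# One patch: residual = (up @ down) * scale * lora_weight, with the scale assumed to be 1.0.
residual = (up @ down) * lora_weight
assert torch.equal(residual, torch.full((out_features, in_features), lora_rank * lora_weight))

# num_loras identical patches add lora_rank * lora_weight * num_loras to every entry.
expected_total = torch.full((out_features, in_features), lora_rank * lora_weight * num_loras)
assert torch.equal(residual * num_loras, expected_total)
```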
@@ -1,197 +0,0 @@
import pytest
import torch

from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher


class DummyModule(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
        super().__init__()
        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_1(x)


@pytest.mark.parametrize(
    ["device", "num_layers"],
    [
        ("cpu", 1),
        pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
        ("cpu", 2),
        pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@torch.no_grad()
def test_apply_lora_patches(device: str, num_layers: int):
    """Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the
    correct result, and that model/LoRA tensors are moved between devices as expected.
    """

    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)

    with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        for lora, _ in lora_models:
            assert lora.layers["linear_layer_1"].up.device.type == "cpu"
            assert lora.layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on its original device.
        assert model.linear_layer_1.weight.data.device.type == device

        torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)

    # After unpatching, the original model weights should have been restored on the original device.
    assert model.linear_layer_1.weight.data.device.type == device
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_patches_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

    with LayerPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model.linear_layer_1.weight.data.device.type == "cpu"

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model.linear_layer_1.weight.data.device.type == "cuda"
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)


@pytest.mark.parametrize(
    ["device", "num_layers"],
    [
        ("cpu", 1),
        pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
        ("cpu", 2),
        pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
def test_apply_lora_sidecar_patches(device: str, num_layers: int):
    """Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
    dtype = torch.float16
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    # Run inference before patching the model.
    input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
    output_before_patch = model(input)

    # Patch the model and run inference during the patch.
    with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_during_patch = model(input)

    # Run inference after unpatching.
    output_after_patch = model(input)

    # Check that the output before patching is different from the output during patching.
    assert not torch.allclose(output_before_patch, output_during_patch)

    # Check that the output before patching is the same as the output after patching.
    assert torch.allclose(output_before_patch, output_after_patch)


@torch.no_grad()
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
    """Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
    dtype = torch.float32
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

    with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
        output_lora_patches = model(input)

    with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_lora_sidecar_patches = model(input)

    # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
    # differences are tolerable and expected due to the difference between sidecar vs. patching.
    assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)