InvokeAI/invokeai/backend/patches/layer_patcher.py

from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class LayerPatcher:
    @staticmethod
    @torch.no_grad()
    @contextmanager
    def apply_smart_model_patches(
        model: torch.nn.Module,
        patches: Iterable[Tuple[ModelPatchRaw, float]],
        prefix: str,
        dtype: torch.dtype,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
        force_direct_patching: bool = False,
        force_sidecar_patching: bool = False,
    ):
        """Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
        module.
        """
        # original_weights are stored for unpatching layers that are directly patched.
        original_weights = OriginalWeightsStorage(cached_weights)
        # original_modules are stored for unpatching layers that are wrapped.
        original_modules: dict[str, torch.nn.Module] = {}
        try:
            for patch, patch_weight in patches:
                LayerPatcher.apply_smart_model_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,
                    patch_weight=patch_weight,
                    original_weights=original_weights,
                    original_modules=original_modules,
                    dtype=dtype,
                    force_direct_patching=force_direct_patching,
                    force_sidecar_patching=force_sidecar_patching,
                )

            yield
        finally:
            # Restore directly patched layers.
            for param_key, weight in original_weights.get_changed_weights():
                cur_param = model.get_parameter(param_key)
                cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)

            # Clear patches from all patched modules.
            # Note: This logic assumes no nested modules in original_modules.
            for orig_module in original_modules.values():
                orig_module.clear_patches()
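
    # A minimal usage sketch of the context manager above. The names `transformer`,
    # `lora_patch`, and the "lora_transformer_" prefix are illustrative assumptions,
    # not defined in this module:
    #
    #     with LayerPatcher.apply_smart_model_patches(
    #         model=transformer,
    #         patches=[(lora_patch, 0.75)],
    #         prefix="lora_transformer_",
    #         dtype=torch.bfloat16,
    #     ):
    #         ...  # run inference with the patched model
    #     # On exit, directly patched weights are restored and sidecar patches are cleared.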

    @staticmethod
    @torch.no_grad()
    def apply_smart_model_patch(
        model: torch.nn.Module,
        prefix: str,
        patch: ModelPatchRaw,
        patch_weight: float,
        original_weights: OriginalWeightsStorage,
        original_modules: dict[str, torch.nn.Module],
        dtype: torch.dtype,
        force_direct_patching: bool,
        force_sidecar_patching: bool,
    ):
        """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use
        direct patching or a sidecar wrapper for each module.
        """
        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LayerPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            # Decide whether to use direct patching or a sidecar patch.
            # Direct patching is preferred, because it results in better runtime speed.
            # Reasons to use sidecar patching:
            # - The module is quantized, so the caller passed force_sidecar_patching=True.
            # - The module already has sidecar patches.
            # - The module is on the CPU (and we don't want to store a second full copy of the original weights on
            #   the CPU, since this would double the RAM usage).
            # NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the
            # caller and that the caller will set force_sidecar_patching=True if the layer is quantized.
            # TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that
            # allows forcing full patching even on the CPU?
            use_sidecar_patching = False
            if force_direct_patching and force_sidecar_patching:
                raise ValueError("Cannot force both direct and sidecar patching.")
            elif force_direct_patching:
                use_sidecar_patching = False
            elif force_sidecar_patching:
                use_sidecar_patching = True
            elif module.get_num_patches() > 0:
                use_sidecar_patching = True
            elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
                use_sidecar_patching = True

            if use_sidecar_patching:
                LayerPatcher._apply_model_layer_wrapper_patch(
                    module_to_patch=module,
                    module_to_patch_key=module_key,
                    patch=layer,
                    patch_weight=patch_weight,
                    original_modules=original_modules,
                    dtype=dtype,
                )
            else:
                LayerPatcher._apply_model_layer_patch(
                    module_to_patch=module,
                    module_to_patch_key=module_key,
                    patch=layer,
                    patch_weight=patch_weight,
                    original_weights=original_weights,
                )

    @staticmethod
    def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
        return any(p.device.type == "cpu" for p in layer.parameters())

    @staticmethod
    @torch.no_grad()
    def _apply_model_layer_patch(
        module_to_patch: torch.nn.Module,
        module_to_patch_key: str,
        patch: BaseLayerPatch,
        patch_weight: float,
        original_weights: OriginalWeightsStorage,
    ):
        # All of the LoRA weight calculations will be done on the same device as the module weight.
        # (Performance will be best if this is a CUDA device.)
        first_param = next(module_to_patch.parameters())
        device = first_param.device
        dtype = first_param.dtype

        # We intentionally move to the target device first, then cast. Experimentally, this was found to
        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
        # same thing in a single call to '.to(...)'.
        patch.to(device=device)
        patch.to(dtype=torch.float32)

        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
        for param_name, param_weight in patch.get_parameters(
            dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
        ).items():
            param_key = module_to_patch_key + "." + param_name
            module_param = module_to_patch.get_parameter(param_name)

            # Save original weight.
            original_weights.save(param_key, module_param)

            # HACK(ryand): This condition is only necessary to handle layers in FLUX control LoRAs that change the
            # shape of the original layer.
            if module_param.nelement() != param_weight.nelement():
                assert isinstance(patch, FluxControlLoRALayer)
                expanded_weight = pad_with_zeros(module_param, param_weight.shape)
                setattr(
                    module_to_patch,
                    param_name,
                    torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
                )
                module_param = expanded_weight

            module_param += param_weight.to(dtype=dtype)

        patch.to(device=TorchDevice.CPU_DEVICE)
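
    # Conceptually, direct patching performs an in-place update of the form
    #     module_param += delta
    # where `delta` comes from BaseLayerPatch.get_parameters() and is already scaled by
    # `patch_weight`. For a plain LoRA layer this is roughly patch_weight * scale * (up @ down),
    # but the exact formula is owned by each patch subclass, not by this module.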

    @staticmethod
    @torch.no_grad()
    def _apply_model_layer_wrapper_patch(
        module_to_patch: torch.nn.Module,
        module_to_patch_key: str,
        patch: BaseLayerPatch,
        patch_weight: float,
        original_modules: dict[str, torch.nn.Module],
        dtype: torch.dtype,
    ):
        """Apply a single LoRA wrapper patch to a module."""
        # Move the LoRA layer to the same device/dtype as the orig module.
        first_param = next(module_to_patch.parameters())
        device = first_param.device
        patch.to(device=device, dtype=dtype)

        if module_to_patch_key not in original_modules:
            original_modules[module_to_patch_key] = module_to_patch

        module_to_patch.add_patch(patch, patch_weight)
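
    # Sidecar patching relies on the target module exposing the duck-typed interface used
    # here and in apply_smart_model_patch: add_patch(patch, weight), get_num_patches(), and
    # clear_patches() (provided elsewhere in the codebase by custom module classes). A
    # hypothetical minimal implementation of that interface, for illustration only:
    #
    #     class PatchableLinear(torch.nn.Linear):
    #         def __init__(self, *args, **kwargs):
    #             super().__init__(*args, **kwargs)
    #             self._patches: list[tuple[BaseLayerPatch, float]] = []
    #
    #         def add_patch(self, patch: BaseLayerPatch, weight: float) -> None:
    #             self._patches.append((patch, weight))
    #
    #         def get_num_patches(self) -> int:
    #             return len(self._patches)
    #
    #         def clear_patches(self) -> None:
    #             self._patches.clear()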

    @staticmethod
    def _split_parent_key(module_key: str) -> tuple[str, str]:
        """Split a module key into its parent key and module name.

        Args:
            module_key (str): The module key to split.

        Returns:
            tuple[str, str]: A tuple containing the parent key and module name.
        """
        split_key = module_key.rsplit(".", 1)
        if len(split_key) == 2:
            return tuple(split_key)
        elif len(split_key) == 1:
            return "", split_key[0]
        else:
            raise ValueError(f"Invalid module key: {module_key}")
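
    # For example: _split_parent_key("transformer.blocks.0.attn") returns ("transformer.blocks.0", "attn"),
    # and _split_parent_key("attn") returns ("", "attn").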

    @staticmethod
    def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
        try:
            submodule_index = int(module_name)
            # If the module name is an integer, then we use the __setitem__ method to set the submodule.
            parent_module[submodule_index] = submodule  # type: ignore
        except ValueError:
            # If the module name is not an integer, then we use the setattr method to set the submodule.
            setattr(parent_module, module_name, submodule)
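
    # For example, if `parent_module` is a torch.nn.ModuleList and module_name is "3", this does
    # parent_module[3] = submodule; if module_name is "proj_out" (a hypothetical attribute name),
    # it does setattr(parent_module, "proj_out", submodule).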

    @staticmethod
    def _get_submodule(
        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
    ) -> tuple[str, torch.nn.Module]:
        """Get the submodule corresponding to the given layer key.

        Args:
            model (torch.nn.Module): The model to search.
            layer_key (str): The layer key to search for.
            layer_key_is_flattened (bool): Whether the layer key is flattened. If flattened, then all '.' have been
                replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed
                directly without searching, but some legacy code still uses flattened keys.

        Returns:
            tuple[str, torch.nn.Module]: A tuple containing the module key and the submodule.
        """
        if not layer_key_is_flattened:
            return layer_key, model.get_submodule(layer_key)

        # Handle flattened keys.
        assert "." not in layer_key

        module = model
        module_key = ""
        key_parts = layer_key.split("_")

        submodule_name = key_parts.pop(0)
        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")
        return module_key, module
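
    # Illustration of the flattened-key search above (the key is hypothetical): for
    # layer_key "single_blocks_0_linear1", the loop tries "single" (fails), grows the
    # candidate to "single_blocks" (succeeds), then resolves "0" and "linear1" in turn,
    # finally returning ("single_blocks.0.linear1", <that submodule>).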