from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class LayerPatcher:
    @staticmethod
    @torch.no_grad()
    @contextmanager
    def apply_smart_model_patches(
        model: torch.nn.Module,
        patches: Iterable[Tuple[ModelPatchRaw, float]],
        prefix: str,
        dtype: torch.dtype,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
        force_direct_patching: bool = False,
        force_sidecar_patching: bool = False,
    ):
        """Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
        module.
        """

        # original_weights are stored for unpatching layers that are directly patched.
        original_weights = OriginalWeightsStorage(cached_weights)
        # original_modules are stored for unpatching layers that are wrapped.
        original_modules: dict[str, torch.nn.Module] = {}
        try:
            for patch, patch_weight in patches:
                LayerPatcher.apply_smart_model_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,
                    patch_weight=patch_weight,
                    original_weights=original_weights,
                    original_modules=original_modules,
                    dtype=dtype,
                    force_direct_patching=force_direct_patching,
                    force_sidecar_patching=force_sidecar_patching,
                )

            yield
        finally:
            # Restore directly patched layers.
            for param_key, weight in original_weights.get_changed_weights():
                cur_param = model.get_parameter(param_key)
                cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)

            # Clear patches from all patched modules.
            # Note: This logic assumes no nested modules in original_modules.
            for orig_module in original_modules.values():
                orig_module.clear_patches()

    @staticmethod
    @torch.no_grad()
    def apply_smart_model_patch(
        model: torch.nn.Module,
        prefix: str,
        patch: ModelPatchRaw,
        patch_weight: float,
        original_weights: OriginalWeightsStorage,
        original_modules: dict[str, torch.nn.Module],
        dtype: torch.dtype,
        force_direct_patching: bool,
        force_sidecar_patching: bool,
    ):
        """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
        patching or a sidecar wrapper for each module.
        """
        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LayerPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            # Decide whether to use direct patching or a sidecar patch.
            # Direct patching is preferred, because it results in better runtime speed.
            # Reasons to use sidecar patching:
            # - The module is quantized, so the caller passed force_sidecar_patching=True.
            # - The module already has sidecar patches.
            # - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
            #   CPU, since this would double the RAM usage)
            # NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
            # and that the caller will set force_sidecar_patching=True if the layer is quantized.
            # TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
            # forcing full patching even on the CPU?
            use_sidecar_patching = False
            if force_direct_patching and force_sidecar_patching:
                raise ValueError("Cannot force both direct and sidecar patching.")
            elif force_direct_patching:
                use_sidecar_patching = False
            elif force_sidecar_patching:
                use_sidecar_patching = True
            elif module.get_num_patches() > 0:
                use_sidecar_patching = True
            elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
                use_sidecar_patching = True

            if use_sidecar_patching:
                LayerPatcher._apply_model_layer_wrapper_patch(
                    module_to_patch=module,
                    module_to_patch_key=module_key,
                    patch=layer,
                    patch_weight=patch_weight,
                    original_modules=original_modules,
                    dtype=dtype,
                )
            else:
                LayerPatcher._apply_model_layer_patch(
                    module_to_patch=module,
                    module_to_patch_key=module_key,
                    patch=layer,
                    patch_weight=patch_weight,
                    original_weights=original_weights,
                )

    @staticmethod
    def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
        return any(p.device.type == "cpu" for p in layer.parameters())

    @staticmethod
    @torch.no_grad()
    def _apply_model_layer_patch(
        module_to_patch: torch.nn.Module,
        module_to_patch_key: str,
        patch: BaseLayerPatch,
        patch_weight: float,
        original_weights: OriginalWeightsStorage,
    ):
        # All of the LoRA weight calculations will be done on the same device as the module weight.
        # (Performance will be best if this is a CUDA device.)
        first_param = next(module_to_patch.parameters())
        device = first_param.device
        dtype = first_param.dtype

        # We intentionally move to the target device first, then cast. Experimentally, this was found to
        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
        # same thing in a single call to '.to(...)'.
        patch.to(device=device)
        patch.to(dtype=torch.float32)

        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
        for param_name, param_weight in patch.get_parameters(
            dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
        ).items():
            param_key = module_to_patch_key + "." + param_name
            module_param = module_to_patch.get_parameter(param_name)

            # Save original weight
            original_weights.save(param_key, module_param)

            # HACK(ryand): This condition is only necessary to handle layers in FLUX control LoRAs that change the
            # shape of the original layer.
            if module_param.nelement() != param_weight.nelement():
                assert isinstance(patch, FluxControlLoRALayer)
                expanded_weight = pad_with_zeros(module_param, param_weight.shape)
                setattr(
                    module_to_patch,
                    param_name,
                    torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
                )
                module_param = expanded_weight

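            # Merge the patch's weighted contribution for this parameter directly into the module weight, in-place
            # (this is the "direct patching" path). The original value was saved above so it can be restored later.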
            module_param += param_weight.to(dtype=dtype)

        patch.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    @torch.no_grad()
    def _apply_model_layer_wrapper_patch(
        module_to_patch: torch.nn.Module,
        module_to_patch_key: str,
        patch: BaseLayerPatch,
        patch_weight: float,
        original_modules: dict[str, torch.nn.Module],
        dtype: torch.dtype,
    ):
        """Apply a single LoRA wrapper patch to a module."""
        # Move the LoRA layer to the same device/dtype as the orig module.
        first_param = next(module_to_patch.parameters())
        device = first_param.device
        patch.to(device=device, dtype=dtype)

        if module_to_patch_key not in original_modules:
            original_modules[module_to_patch_key] = module_to_patch

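        # NOTE: `add_patch(...)`, like `get_num_patches()` and `clear_patches()` used elsewhere in this class, is
        # not a standard `torch.nn.Module` method. This assumes that `module_to_patch` is one of the custom module
        # wrappers used in this codebase that maintain a list of sidecar patches.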
        module_to_patch.add_patch(patch, patch_weight)

    @staticmethod
    def _split_parent_key(module_key: str) -> tuple[str, str]:
        """Split a module key into its parent key and module name.

        Args:
            module_key (str): The module key to split.

        Returns:
            tuple[str, str]: A tuple containing the parent key and module name.
        """
        split_key = module_key.rsplit(".", 1)
        if len(split_key) == 2:
            return tuple(split_key)
        elif len(split_key) == 1:
            return "", split_key[0]
        else:
            raise ValueError(f"Invalid module key: {module_key}")

    @staticmethod
    def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
        try:
            submodule_index = int(module_name)
            # If the module name is an integer, then we use the __setitem__ method to set the submodule.
            parent_module[submodule_index] = submodule  # type: ignore
        except ValueError:
            # If the module name is not an integer, then we use the setattr method to set the submodule.
            setattr(parent_module, module_name, submodule)

    @staticmethod
    def _get_submodule(
        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
    ) -> tuple[str, torch.nn.Module]:
        """Get the submodule corresponding to the given layer key.

        Args:
            model (torch.nn.Module): The model to search.
            layer_key (str): The layer key to search for.
            layer_key_is_flattened (bool): Whether the layer key is flattened. If flattened, then all '.' have been
                replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed
                directly without searching, but some legacy code still uses flattened keys.

        Returns:
            tuple[str, torch.nn.Module]: A tuple containing the module key and the submodule.
        """
        if not layer_key_is_flattened:
            return layer_key, model.get_submodule(layer_key)

        # Handle flattened keys.
        assert "." not in layer_key
        module = model
        module_key = ""
        key_parts = layer_key.split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return module_key, module