Minor cleanup and documentation updates.

Ryan Dick 2024-09-13 13:57:00 +00:00 committed by Kent Keirsey
parent ae41651346
commit 61d3d566de
5 changed files with 47 additions and 28 deletions

View File

@@ -40,7 +40,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
- raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
+ raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
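The change above only swaps the exception type; as a hedged illustration of the duplicate-LoRA validation pattern it sits in (the class names below are hypothetical stand-ins, not the real invocation fields):

```python
from dataclasses import dataclass, field


@dataclass
class LoRAField:
    """Hypothetical stand-in for a LoRA reference carried on a transformer field."""
    key: str
    weight: float = 1.0


@dataclass
class TransformerField:
    """Hypothetical stand-in for the transformer connection that accumulates LoRAs."""
    loras: list[LoRAField] = field(default_factory=list)


def add_lora(transformer: TransformerField, lora_key: str, weight: float) -> None:
    # Raising ValueError (rather than a bare Exception) signals a bad input value
    # and lets callers catch it specifically.
    if any(lora.key == lora_key for lora in transformer.loras):
        raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
    transformer.loras.append(LoRAField(key=lora_key, weight=weight))
```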

View File

@@ -30,7 +30,7 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
return all_keys_in_peft_format and all_expected_keys_present
- def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
+ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw:
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
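For context on the two functions touched above, here is a hedged sketch of what a "does this state dict look like Diffusers FLUX LoRA format" check can resemble. The key patterns (`lora_A`/`lora_B` suffixes, a `transformer.` prefix) are assumptions for illustration; the real `is_state_dict_likely_in_flux_diffusers_format` may use different heuristics:

```python
import re
from typing import Dict

import torch


def looks_like_diffusers_flux_lora(state_dict: Dict[str, torch.Tensor]) -> bool:
    """Illustrative check only; the real heuristics may differ."""
    # PEFT-style LoRA keys end in lora_A.weight / lora_B.weight (assumed pattern).
    peft_key = re.compile(r".*\.lora_(A|B)\.weight$")
    all_keys_in_peft_format = all(peft_key.match(k) for k in state_dict.keys())

    # Expect at least one key targeting the FLUX transformer (assumed prefix).
    any_transformer_key = any(k.startswith("transformer.") for k in state_dict.keys())

    return all_keys_in_peft_format and any_transformer_key
```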

View File

@@ -41,6 +41,3 @@ class ConcatenatedLoRALayer(LoRALayerBase):
assert len(layer_biases) == len(self.lora_layers)
return torch.cat(layer_biases, dim=self.concat_axis)
- def calc_size(self) -> int:
- return sum(lora_layer.calc_size() for lora_layer in self.lora_layers)
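The removed `calc_size` aside, the surviving `torch.cat(layer_biases, dim=self.concat_axis)` is the heart of `ConcatenatedLoRALayer`: per-sub-layer deltas are concatenated to match a fused base layer (e.g. a fused q/k/v projection). A minimal sketch; the function name and shapes are illustrative only:

```python
import torch


def concat_bias_deltas(layer_biases: list[torch.Tensor], concat_axis: int = 0) -> torch.Tensor:
    """Illustrative only: combine per-sub-layer bias deltas into one delta that
    matches a fused base layer."""
    return torch.cat(layer_biases, dim=concat_axis)


# Three sub-layer bias deltas of size 8 become one delta of size 24 for the fused layer.
deltas = [torch.zeros(8), torch.zeros(8), torch.zeros(8)]
assert concat_bias_deltas(deltas).shape == (24,)
```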

View File

@@ -28,11 +28,14 @@ class LoRAPatcher:
):
"""Apply one or more LoRA patches to a model within a context manager.
- :param model: The model to patch.
- :param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
- that the LoRA patches do not need to be loaded into memory all at once.
- :param prefix: The keys in the patches will be filtered to only include weights with this prefix.
- :cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
+ Args:
+ model (torch.nn.Module): The model to patch.
+ patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
+ associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
+ all at once.
+ prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
+ cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
+ CPU RAM, for efficient unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
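A hedged usage sketch of the context-manager API documented above, assuming it is exposed as `LoRAPatcher.apply_lora_patches` with the parameters listed in the new docstring; the import paths and the `prefix` value are placeholders:

```python
import torch

from invokeai.backend.lora.lora_model_raw import LoRAModelRaw  # import path assumed
from invokeai.backend.lora.lora_patcher import LoRAPatcher  # import path assumed


def run_with_loras(
    model: torch.nn.Module,
    loras: list[tuple[LoRAModelRaw, float]],
    x: torch.Tensor,
) -> torch.Tensor:
    # Patches are applied inside the `with` block and reverted on exit, using the
    # original weights captured in OriginalWeightsStorage.
    # prefix="" is a placeholder (an empty prefix keeps all keys).
    with LoRAPatcher.apply_lora_patches(model=model, patches=iter(loras), prefix=""):
        return model(x)
```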
@@ -60,15 +63,15 @@ class LoRAPatcher:
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply a single LoRA patch to a model.
:param model: The model to patch.
:param patch: LoRA model to patch in.
:param patch_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
"""Apply a single LoRA patch to a model.
Args:
model (torch.nn.Module): The model to patch.
prefix (str): A string prefix that precedes keys used in the LoRAs weight layers.
patch (LoRAModelRaw): The LoRA model to patch in.
patch_weight (float): The weight of the LoRA patch.
original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
"""
if patch_weight == 0:
return
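For readers unfamiliar with what `_apply_lora_patch` ultimately does to each layer, this is the standard LoRA weight-folding math, shown as a standalone sketch rather than the project's exact implementation (which also handles device/dtype moves and saving originals for unpatching):

```python
import torch


def apply_lora_delta(
    weight: torch.Tensor,      # (out_features, in_features) base weight, modified in place
    lora_down: torch.Tensor,   # (rank, in_features)
    lora_up: torch.Tensor,     # (out_features, rank)
    alpha: float,
    patch_weight: float,
) -> None:
    """Standard LoRA folding: W += patch_weight * (alpha / rank) * (up @ down)."""
    rank = lora_down.shape[0]
    scale = patch_weight * (alpha / rank)
    weight += scale * (lora_up @ lora_down)


w = torch.zeros(16, 32)
down = torch.randn(4, 32)
up = torch.randn(16, 4)
apply_lora_delta(w, down, up, alpha=4.0, patch_weight=0.75)
```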
@@ -126,6 +129,17 @@ class LoRAPatcher:
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
quantization format.
Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
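The `original_modules` bookkeeping above exists so that every wrapped submodule can be swapped back in the `finally` block. A stripped-down sketch of that swap-and-restore pattern; the wrapper class and helper are placeholders, not the project's sidecar classes:

```python
from contextlib import contextmanager

import torch


class _Wrapper(torch.nn.Module):
    """Placeholder sidecar wrapper: forwards to the original module unchanged."""

    def __init__(self, orig: torch.nn.Module):
        super().__init__()
        self.orig = orig

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.orig(x)


@contextmanager
def wrap_submodules(model: torch.nn.Module, keys: list[str]):
    # Record the original submodule under each key so it can be restored on exit,
    # mirroring the original_modules bookkeeping in the method above.
    original_modules: dict[str, torch.nn.Module] = {}
    try:
        for key in keys:
            parent_key, _, name = key.rpartition(".")
            parent = model.get_submodule(parent_key) if parent_key else model
            original_modules[key] = getattr(parent, name)
            setattr(parent, name, _Wrapper(original_modules[key]))
        yield
    finally:
        # Restore original modules even if the body raised.
        for key, orig in original_modules.items():
            parent_key, _, name = key.rpartition(".")
            parent = model.get_submodule(parent_key) if parent_key else model
            setattr(parent, name, orig)


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
with wrap_submodules(model, ["0"]):
    assert isinstance(model[0], _Wrapper)
assert isinstance(model[0], torch.nn.Linear)
```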
@@ -136,7 +150,6 @@ class LoRAPatcher:
patch_weight=patch_weight,
original_modules=original_modules,
)
yield
finally:
# Restore original modules.
@@ -154,6 +167,8 @@ class LoRAPatcher:
prefix: str,
original_modules: dict[str, torch.nn.Module],
):
"""Apply a single LoRA sidecar patch to a model."""
if patch_weight == 0:
return
@@ -178,8 +193,8 @@ class LoRAPatcher:
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
- # HACK(ryand): Set the dtype properly here. We want to set it to the *compute* dtype of the original module.
- # In the case of quantized layers, this may be different than the weight dtype.
+ # HACK(ryand): Figure out how to set the dtype properly here. We want to set it to the *compute* dtype of
+ # the original module. In the case of quantized layers, this may be different than the weight dtype.
lora_sidecar_layer.to(device=module.weight.device, dtype=torch.bfloat16)
if module_key in original_modules:
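The reworded HACK comment is about the fact that a quantized layer's stored weight dtype (e.g. int8) is not a valid activation dtype, so the compute dtype has to come from somewhere else; hard-coding `torch.bfloat16` is the current workaround. A hedged sketch of one possible heuristic, purely illustrative:

```python
import torch


def guess_compute_dtype(module: torch.nn.Module, default: torch.dtype = torch.bfloat16) -> torch.dtype:
    """Illustrative heuristic only. For quantized layers the stored weight dtype
    (e.g. int8) is not usable for activations, so fall back to a configured default,
    which is roughly what the hard-coded torch.bfloat16 above does."""
    weight = getattr(module, "weight", None)
    if weight is not None and weight.dtype.is_floating_point:
        return weight.dtype
    return default
```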
@@ -196,6 +211,7 @@ class LoRAPatcher:
@staticmethod
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
+ # TODO(ryand): Add support for more original layer types and LoRA layer types.
if isinstance(orig_layer, torch.nn.Linear):
if isinstance(lora_layer, LoRALayer):
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
@@ -211,7 +227,7 @@ class LoRAPatcher:
try:
submodule_index = int(module_name)
# If the module name is an integer, then we use the __setitem__ method to set the submodule.
- parent_module[submodule_index] = submodule
+ parent_module[submodule_index] = submodule # type: ignore
except ValueError:
# If the module name is not an integer, then we use the setattr method to set the submodule.
setattr(parent_module, module_name, submodule)
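The integer-vs-attribute branch above exists because containers like `nn.Sequential` address children by position. An illustrative version of that branch with a small usage example:

```python
import torch


def set_child(parent: torch.nn.Module, name: str, child: torch.nn.Module) -> None:
    """Illustrative version of the _set_submodule branch above."""
    try:
        # Children of containers like nn.Sequential are addressed by integer index.
        parent[int(name)] = child  # type: ignore[index]
    except ValueError:
        # Non-numeric names are ordinary attributes.
        setattr(parent, name, child)


seq = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
set_child(seq, "1", torch.nn.GELU())    # replaces the ReLU by index
set_child(seq, "act", torch.nn.Tanh())  # registers a named child via setattr
```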
@@ -221,12 +237,16 @@ class LoRAPatcher:
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
) -> tuple[str, torch.nn.Module]:
"""Get the submodule corresponding to the given layer key.
- :param model: The model to search.
- :param layer_key: The layer key to search for.
- :param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
- with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
- searching, but some legacy code still uses flattened keys.
- :return: A tuple containing the module key and the submodule.
+ Args:
+ model (torch.nn.Module): The model to search.
+ layer_key (str): The layer key to search for.
+ layer_key_is_flattened (bool): Whether the layer key is flattened. If flattened, then all '.' have been
+ replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed
+ directly without searching, but some legacy code still uses flattened keys.
+ Returns:
+ tuple[str, torch.nn.Module]: A tuple containing the module key and the submodule.
"""
if not layer_key_is_flattened:
return layer_key, model.get_submodule(layer_key)
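When the key is flattened ('.' replaced with '_'), `get_submodule()` cannot be used directly, so a search is required. A hedged sketch of such a fallback lookup (the project's actual legacy-key handling may differ); note that flattening is lossy when module names themselves contain underscores, which is one reason non-flattened keys are preferred:

```python
import torch


def get_submodule_by_flattened_key(model: torch.nn.Module, flattened_key: str) -> tuple[str, torch.nn.Module]:
    """Illustrative fallback only: compare each real module path with its
    '.' -> '_' flattened form until a match is found."""
    for name, module in model.named_modules():
        if name.replace(".", "_") == flattened_key:
            return name, module
    raise ValueError(f"No submodule matching flattened key: {flattened_key}")
```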

View File

@@ -2,6 +2,8 @@ import torch
class LoRASidecarModule(torch.nn.Module):
"""A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
super().__init__()
self._orig_module = orig_module
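The `__init__` shown above stores the wrapped module and its LoRA layers; the forward pass is not part of this diff, so the following is only a guess at the typical shape of such a sidecar module (original output plus each LoRA layer's contribution):

```python
import torch


class LoRASidecarModuleSketch(torch.nn.Module):
    """Sketch of a sidecar wrapper; the real LoRASidecarModule's forward is not
    shown in this diff, so the summation below is an assumption."""

    def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
        super().__init__()
        self._orig_module = orig_module
        self._lora_layers = torch.nn.ModuleList(lora_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Original output plus each LoRA layer's (already weight-scaled) contribution.
        out = self._orig_module(x)
        for lora_layer in self._lora_layers:
            out = out + lora_layer(x)
        return out
```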