Fix LoRAPatcher.apply_lora_wrapper_patches(...)

Ryan Dick 2024-12-10 03:10:23 +00:00
parent 4c84d39e7d
commit 80128e1e14
3 changed files with 7 additions and 3 deletions

@@ -315,6 +315,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
 model=transformer,
 patches=self._lora_iterator(context),
 prefix=FLUX_LORA_TRANSFORMER_PREFIX,
+dtype=inference_dtype,
 )
 )
 else:

@@ -126,6 +126,7 @@ class LoRAPatcher:
 model: torch.nn.Module,
 patches: Iterable[Tuple[LoRAModelRaw, float]],
 prefix: str,
+dtype: torch.dtype,
 ):
 """Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
 runtime overhead compared to normal LoRA patching, but they enable:
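
The docstring above describes the wrapper-patching approach: rather than baking LoRA deltas into the patched module's weights, a sidecar wrapper keeps the original module untouched and adds the weighted LoRA outputs at forward time, which is where the extra runtime cost comes from. The sketch below illustrates that idea only; MiniLoRALayer and SidecarLinear are invented stand-ins for this example, not the repository's LoRA layer or LoRASidecarWrapper classes.

import torch

class MiniLoRALayer(torch.nn.Module):
    """Toy low-rank adapter: delta(x) = up(down(x)) with a small rank."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(x))

class SidecarLinear(torch.nn.Module):
    """Wraps an unmodified Linear and adds weighted LoRA outputs at runtime."""

    def __init__(self, orig_module: torch.nn.Linear):
        super().__init__()
        self.orig_module = orig_module
        self.lora_layers = torch.nn.ModuleList()
        self.lora_weights: list[float] = []

    def add_lora_layer(self, layer: torch.nn.Module, patch_weight: float) -> None:
        self.lora_layers.append(layer)
        self.lora_weights.append(patch_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The per-call LoRA evaluation is the runtime overhead mentioned above;
        # in exchange, self.orig_module.weight is never modified.
        out = self.orig_module(x)
        for layer, weight in zip(self.lora_layers, self.lora_weights):
            out = out + weight * layer(x)
        return out

base = torch.nn.Linear(16, 16)
wrapped = SidecarLinear(base)
wrapped.add_lora_layer(MiniLoRALayer(16, 16), patch_weight=0.75)
output = wrapped(torch.randn(2, 16))
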
@@ -149,6 +150,7 @@
 patch=patch,
 patch_weight=patch_weight,
 original_modules=original_modules,
+dtype=dtype,
 )
 yield
 finally:
@@ -166,6 +168,7 @@
 patch_weight: float,
 prefix: str,
 original_modules: dict[str, torch.nn.Module],
+dtype: torch.dtype,
 ):
 """Apply a single LoRA wrapper patch to a model."""
@@ -201,7 +204,7 @@
 orig_module = module.orig_module
 # Move the LoRA layer to the same device/dtype as the orig module.
-layer.to(device=orig_module.weight.device, dtype=orig_module.weight.dtype)
+layer.to(device=orig_module.weight.device, dtype=dtype)
 # Add the LoRA wrapper layer to the LoRASidecarWrapper.
 lora_wrapper_layer.add_lora_layer(layer, patch_weight)
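
The layer.to(...) change in the hunk above stops deriving the cast dtype from orig_module.weight.dtype and uses the explicitly supplied dtype instead. One plausible motivation (my assumption, not stated in the diff) is that the wrapped module's stored weight dtype is not always a usable compute dtype, for example with quantized weights. A small self-contained sketch of the difference, with FakeQuantizedLinear as an invented stand-in:

import torch

class FakeQuantizedLinear(torch.nn.Module):
    """Stand-in for a layer whose stored weights are not in the compute dtype."""

    def __init__(self):
        super().__init__()
        # int8 storage, standing in for a quantized base weight.
        self.register_buffer("weight", torch.zeros(8, 8, dtype=torch.int8))

orig_module = FakeQuantizedLinear()
lora_layer = torch.nn.Linear(8, 8)
inference_dtype = torch.bfloat16

# Old behavior: follow the original module's weight dtype. With int8 storage this
# would ask nn.Module.to() for an integer cast, which it refuses:
# lora_layer.to(device=orig_module.weight.device, dtype=orig_module.weight.dtype)

# New behavior: cast the LoRA layer to the caller-supplied inference dtype.
lora_layer.to(device=orig_module.weight.device, dtype=inference_dtype)
assert lora_layer.weight.dtype == inference_dtype
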

@@ -146,7 +146,7 @@ def test_apply_lora_wrapper_patches(device: str, num_layers: int):
 output_before_patch = model(input)
 # Patch the model and run inference during the patch.
-with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix=""):
+with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
 output_during_patch = model(input)
 # Run inference after unpatching.
@@ -189,7 +189,7 @@ def test_all_patching_methods_produce_same_output(num_layers: int):
 with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
 output_lora_patches = model(input)
-with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix=""):
+with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
 output_lora_wrapper_patches = model(input)
 # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
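
The truncated note above explains why the comparison between patching methods uses a relaxed atol=1e-5 rather than the default: the two code paths compute the same result along slightly different numerical routes. A minimal, self-contained sketch of that style of check with torch.testing.assert_close (the tensors here are synthetic placeholders for the two outputs):

import torch

# Synthetic stand-ins for the outputs of two different LoRA patching code paths
# that should agree up to small floating-point differences.
output_a = torch.randn(4, 16)
output_b = output_a + 5e-6  # perturbation smaller than the chosen tolerance

# rtol and atol must be supplied together; the relaxed atol absorbs the slight
# numerical differences between the two computation paths.
torch.testing.assert_close(output_a, output_b, rtol=0.0, atol=1e-5)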