Fix LoRAPatcher.apply_lora_wrapper_patches(...)
This commit is contained in: parent 4c84d39e7d, commit 80128e1e14
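In short: apply_lora_wrapper_patches(...) gains an explicit dtype: torch.dtype parameter, which is threaded through to the per-patch helper and used when casting each LoRA layer, replacing the previous cast to orig_module.weight.dtype. The FLUX denoise invocation now passes dtype=inference_dtype, and the tests pass the dtype they run under.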
@@ -315,6 +315,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
+                    dtype=inference_dtype,
                 )
             )
         else:
@@ -126,6 +126,7 @@ class LoRAPatcher:
         model: torch.nn.Module,
         patches: Iterable[Tuple[LoRAModelRaw, float]],
         prefix: str,
+        dtype: torch.dtype,
     ):
         """Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
         runtime overhead compared to normal LoRA patching, but they enable:
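For orientation, a minimal usage sketch of the updated context manager. Only the keyword signature comes from the hunk above; the model, the patch list, and the chosen dtype are hypothetical stand-ins, and the LoRAPatcher import is omitted because its module path is not visible in this diff.

import torch

# LoRAPatcher import omitted: its module path is not shown in this diff.

model = torch.nn.Linear(8, 8)  # stand-in for the real torch.nn.Module
lora_patches = []  # e.g. [(lora_model, 0.75)] with a real LoRAModelRaw

with LoRAPatcher.apply_lora_wrapper_patches(
    model=model,
    patches=lora_patches,
    prefix="",  # the tests below pass an empty prefix
    dtype=torch.bfloat16,  # assumed value; the FLUX invocation passes inference_dtype
):
    output_during_patch = model(torch.randn(1, 8))
# Leaving the context restores the original, unpatched modules (see the finally: block below).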
@@ -149,6 +150,7 @@ class LoRAPatcher:
                     patch=patch,
                     patch_weight=patch_weight,
                     original_modules=original_modules,
+                    dtype=dtype,
                 )
             yield
         finally:
@@ -166,6 +168,7 @@ class LoRAPatcher:
         patch_weight: float,
         prefix: str,
         original_modules: dict[str, torch.nn.Module],
+        dtype: torch.dtype,
     ):
         """Apply a single LoRA wrapper patch to a model."""
 
@@ -201,7 +204,7 @@ class LoRAPatcher:
             orig_module = module.orig_module
 
             # Move the LoRA layer to the same device/dtype as the orig module.
-            layer.to(device=orig_module.weight.device, dtype=orig_module.weight.dtype)
+            layer.to(device=orig_module.weight.device, dtype=dtype)
 
             # Add the LoRA wrapper layer to the LoRASidecarWrapper.
             lora_wrapper_layer.add_lora_layer(layer, patch_weight)
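This hunk is the behavioral core of the fix: the LoRA layer used to inherit orig_module.weight.dtype, and now it is cast to the caller-supplied dtype. (The unchanged comment above now only describes the device half of the cast.) A minimal sketch of why the two casts differ, using hypothetical modules rather than repo code:

import torch

# Stand-in for a base module whose weights are stored in a low-precision dtype.
orig_module = torch.nn.Linear(4, 4).to(dtype=torch.float16)
lora_layer = torch.nn.Linear(4, 4)  # stand-in for a LoRA layer, float32 by default
inference_dtype = torch.bfloat16

# Old behavior: follow the stored weight dtype (float16 here), which need not
# match the dtype the model actually computes in.
lora_layer.to(device=orig_module.weight.device, dtype=orig_module.weight.dtype)
assert lora_layer.weight.dtype == torch.float16

# New behavior: cast to the explicit inference dtype chosen by the caller.
lora_layer.to(device=orig_module.weight.device, dtype=inference_dtype)
assert lora_layer.weight.dtype == torch.bfloat16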
@@ -146,7 +146,7 @@ def test_apply_lora_wrapper_patches(device: str, num_layers: int):
     output_before_patch = model(input)
 
     # Patch the model and run inference during the patch.
-    with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix=""):
+    with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
         output_during_patch = model(input)
 
     # Run inference after unpatching.
@@ -189,7 +189,7 @@ def test_all_patching_methods_produce_same_output(num_layers: int):
     with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
         output_lora_patches = model(input)
 
-    with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix=""):
+    with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
         output_lora_wrapper_patches = model(input)
 
     # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
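The trailing comment is cut off at the hunk boundary; it explains why the output comparison uses a relaxed absolute tolerance. A small illustration of the tolerance arithmetic it refers to, assuming a torch.allclose-style check (the actual assertion lies outside this hunk):

import torch

# Outputs that drift by ~1e-6, as can happen when two mathematically
# equivalent patching code paths accumulate rounding differently.
a = torch.zeros(4)
b = torch.full((4,), 1e-6)

# Near zero, the default tolerances (rtol=1e-5, atol=1e-8) reject the drift...
assert not torch.allclose(a, b)
# ...while atol=1e-5 accepts it, matching the comment above.
assert torch.allclose(a, b, atol=1e-5)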