Add test_apply_smart_lora_patches_to_partially_loaded_model(...).

This commit is contained in:
Ryan Dick 2024-12-10 16:38:48 +00:00
parent cefcb340d9
commit d0f35fceed

View File

@ -1,9 +1,13 @@
import pytest
import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class DummyModuleWithOneLayer(torch.nn.Module):
@ -220,10 +224,79 @@ def test_apply_smart_model_patches(device: str, num_layers: int):
assert torch.allclose(output_before_patch, output_after_patch)
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
@torch.no_grad()
def test_apply_smart_lora_patches_to_partially_loaded_model(num_layers: int):
"""Test the behavior of ModelPatcher.apply_smart_lora_patches(...) when it is applied to a
CachedModelWithPartialLoad that is partially loaded into VRAM.
"""
if not torch.cuda.is_available():
pytest.skip("requires CUDA device")
# Initialize the model on the CPU.
dtype = torch.float16
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
_ = cached_model.partial_load_to_vram(target_vram_bytes)
assert cached_model.model.linear_layer_1.weight.device.type == "cuda"
assert cached_model.model.linear_layer_2.weight.device.type == "cpu"
# Initialize num_layers LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_layers):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
),
"linear_layer_2": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
),
}
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
# Run inference before patching the model.
input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype)
output_before_patch = cached_model.model(input)
# Patch the model and run inference during the patch.
with LayerPatcher.apply_smart_model_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype):
# Check that the second layer is wrapped in a LoRASidecarWrapper, but the first layer is not.
assert not isinstance(cached_model.model.linear_layer_1, BaseSidecarWrapper)
assert isinstance(cached_model.model.linear_layer_2, BaseSidecarWrapper)
output_during_patch = cached_model.model(input)
# Run inference after unpatching.
output_after_patch = cached_model.model(input)
# Check that the output before patching is different from the output during patching.
assert not torch.allclose(output_before_patch, output_during_patch)
# Check that the output before patching is the same as the output after patching.
assert torch.allclose(output_before_patch, output_after_patch)
@torch.no_grad()
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
"""Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
def test_apply_model_sidecar_patches_matches_apply_model_patches(num_layers: int):
"""Test that apply_model_sidecar_patches(...) produces the same model outputs as apply__patches(...)."""
dtype = torch.float32
linear_in_features = 4
linear_out_features = 8
@ -253,6 +326,10 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_lora_sidecar_patches = model(input)
with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_smart_lora_patches = model(input)
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
# differences are tolerable and expected due to the difference between sidecar vs. patching.
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
assert torch.allclose(output_lora_patches, output_smart_lora_patches, atol=1e-5)