diff --git a/tests/backend/patches/test_lora_patcher.py b/tests/backend/patches/test_lora_patcher.py
index 5561e882fb..dd250b6535 100644
--- a/tests/backend/patches/test_lora_patcher.py
+++ b/tests/backend/patches/test_lora_patcher.py
@@ -1,9 +1,13 @@
 import pytest
 import torch
 
+from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
+    CachedModelWithPartialLoad,
+)
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import LayerPatcher
+from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
 
 
 class DummyModuleWithOneLayer(torch.nn.Module):
@@ -220,10 +224,79 @@ def test_apply_smart_model_patches(device: str, num_layers: int):
     assert torch.allclose(output_before_patch, output_after_patch)
 
 
+@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
+@torch.no_grad()
+def test_apply_smart_lora_patches_to_partially_loaded_model(num_layers: int):
+    """Test the behavior of LayerPatcher.apply_smart_model_patches(...) when it is applied to a
+    CachedModelWithPartialLoad that is partially loaded into VRAM.
+    """
+
+    if not torch.cuda.is_available():
+        pytest.skip("requires CUDA device")
+
+    # Initialize the model on the CPU.
+    dtype = torch.float16
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_rank = 2
+    model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
+    model_total_bytes = cached_model.total_bytes()
+    assert cached_model.cur_vram_bytes() == 0
+
+    # Partially load the model into VRAM.
+    target_vram_bytes = int(model_total_bytes * 0.6)
+    _ = cached_model.partial_load_to_vram(target_vram_bytes)
+    assert cached_model.model.linear_layer_1.weight.device.type == "cuda"
+    assert cached_model.model.linear_layer_2.weight.device.type == "cpu"
+
+    # Initialize num_layers LoRA models, each with a patch weight of 0.5.
+    lora_weight = 0.5
+    lora_models: list[tuple[ModelPatchRaw, float]] = []
+    for _ in range(num_layers):
+        lora_layers = {
+            "linear_layer_1": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            ),
+            "linear_layer_2": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            ),
+        }
+        lora = ModelPatchRaw(lora_layers)
+        lora_models.append((lora, lora_weight))
+
+    # Run inference before patching the model.
+    input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype)
+    output_before_patch = cached_model.model(input)
+
+    # Patch the model and run inference while the patches are applied.
+    with LayerPatcher.apply_smart_model_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype):
+        # Check that the second layer (still on the CPU) is wrapped in a sidecar wrapper, but the first
+        # layer (in VRAM) is not.
+        assert not isinstance(cached_model.model.linear_layer_1, BaseSidecarWrapper)
+        assert isinstance(cached_model.model.linear_layer_2, BaseSidecarWrapper)
+
+        output_during_patch = cached_model.model(input)
+
+    # Run inference after unpatching.
+    output_after_patch = cached_model.model(input)
+
+    # Check that the output before patching is different from the output during patching.
+    assert not torch.allclose(output_before_patch, output_during_patch)
+
+    # Check that the output before patching is the same as the output after patching.
+    assert torch.allclose(output_before_patch, output_after_patch)
+
+
 @torch.no_grad()
 @pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
-def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
-    """Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
+def test_apply_model_sidecar_patches_matches_apply_model_patches(num_layers: int):
+    """Test that apply_model_sidecar_patches(...) produces the same model outputs as apply_model_patches(...)."""
     dtype = torch.float32
     linear_in_features = 4
     linear_out_features = 8
@@ -253,6 +326,10 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
     with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
         output_lora_sidecar_patches = model(input)
 
+    with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+        output_smart_lora_patches = model(input)
+
     # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
     # differences are tolerable and expected due to the difference between sidecar vs. patching.
     assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
+    assert torch.allclose(output_lora_patches, output_smart_lora_patches, atol=1e-5)
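
Reviewer note: the new partial-load test pins down the per-layer decision that apply_smart_model_patches makes. Below is a minimal sketch of that decision for context; it is illustrative only, choose_patch_strategy is a hypothetical helper (not part of this PR), and the real logic lives in LayerPatcher.apply_smart_model_patches.

    import torch

    def choose_patch_strategy(module: torch.nn.Module) -> str:
        # Hypothetical helper, for illustration only: pick a patching strategy
        # based on where the layer's weights currently live.
        if next(module.parameters()).device.type == "cuda":
            # Weights are resident in VRAM: patch them directly, no wrapper needed.
            return "direct-weight-patch"
        # Weights are still on the CPU (not yet moved by partial_load_to_vram):
        # wrap the layer in a sidecar so the original weights stay untouched.
        return "sidecar-wrapper"

This is why the test asserts that linear_layer_1 (moved to CUDA by partial_load_to_vram) is not a BaseSidecarWrapper, while linear_layer_2 (still on the CPU) is.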