diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py
index aa6acd31c5..2b9d8e9e98 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py
@@ -2,6 +2,9 @@ import bitsandbytes as bnb
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
+    autocast_linear_forward_sidecar_patches,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
@@ -9,6 +12,9 @@ from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
 
 
 class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
+        return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
+
     def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         matmul_state = bnb.MatmulLtState()
         matmul_state.threshold = self.state.threshold
@@ -30,7 +36,9 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
         return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self._device_autocasting_enabled:
+        if len(self._patches_and_weights) > 0:
+            return self._autocast_forward_with_patches(x)
+        elif self._device_autocasting_enabled:
             return self._autocast_forward(x)
         else:
             return super().forward(x)
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py
index 60e987b3f3..89284d5509 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py
@@ -4,6 +4,9 @@ import bitsandbytes as bnb
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
+    autocast_linear_forward_sidecar_patches,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
@@ -11,6 +14,9 @@ from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
 
 
 class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
+        return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
+
     def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)
 
@@ -48,7 +54,9 @@ class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
         return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self._device_autocasting_enabled:
+        if len(self._patches_and_weights) > 0:
+            return self._autocast_forward_with_patches(x)
+        elif self._device_autocasting_enabled:
             return self._autocast_forward(x)
         else:
             return super().forward(x)
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
index b01a744be6..666ea1d8cf 100644
--- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
@@ -25,13 +25,15 @@ from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_m
 from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
 
 
-def build_linear_layer_with_ggml_quantized_tensor():
-    layer = torch.nn.Linear(32, 64)
-    ggml_quantized_weight = quantize_tensor(layer.weight, gguf.GGMLQuantizationType.Q8_0)
-    layer.weight = torch.nn.Parameter(ggml_quantized_weight)
-    ggml_quantized_bias = quantize_tensor(layer.bias, gguf.GGMLQuantizationType.Q8_0)
-    layer.bias = torch.nn.Parameter(ggml_quantized_bias)
-    return layer
+def build_linear_layer_with_ggml_quantized_tensor(orig_layer: torch.nn.Linear | None = None):
+    if orig_layer is None:
+        orig_layer = torch.nn.Linear(32, 64)
+
+    ggml_quantized_weight = quantize_tensor(orig_layer.weight, gguf.GGMLQuantizationType.Q8_0)
+    orig_layer.weight = torch.nn.Parameter(ggml_quantized_weight)
+    ggml_quantized_bias = quantize_tensor(orig_layer.bias, gguf.GGMLQuantizationType.Q8_0)
+    orig_layer.bias = torch.nn.Parameter(ggml_quantized_bias)
+    return orig_layer
 
 
 parameterize_all_devices = pytest.mark.parametrize(
@@ -267,30 +269,29 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
     assert torch.allclose(orig_output, custom_output)
 
 
-LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float]], torch.Tensor, bool]
+PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor]
 
 
 @pytest.fixture(
     params=[
-        "linear_single_lora",
-        "linear_multiple_loras",
-        "linear_concatenated_lora",
-        "linear_flux_control_lora",
+        "single_lora",
+        "multiple_loras",
+        "concatenated_lora",
+        "flux_control_lora",
     ]
 )
-def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
-    """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
+def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
+    """A fixture that returns a tuple of (patches, input) for the patch under test."""
     layer_type = request.param
     torch.manual_seed(0)
 
-    if layer_type == "linear_single_lora":
-        # Create a linear layer.
-        in_features = 10
-        out_features = 20
-        layer = torch.nn.Linear(in_features, out_features)
+    # The assumed in/out features of the base linear layer.
+    in_features = 32
+    out_features = 64
 
-        # Create a LoRA layer.
-        rank = 4
+    rank = 4
+
+    if layer_type == "single_lora":
         lora_layer = LoRALayer(
             up=torch.randn(out_features, rank),
             mid=None,
@@ -299,14 +300,8 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
             bias=torch.randn(out_features),
         )
         input = torch.randn(1, in_features)
-        return (layer, [(lora_layer, 0.7)], input, True)
-    elif layer_type == "linear_multiple_loras":
-        # Create a linear layer.
-        rank = 4
-        in_features = 10
-        out_features = 20
-        layer = torch.nn.Linear(in_features, out_features)
-
+        return ([(lora_layer, 0.7)], input)
+    elif layer_type == "multiple_loras":
         lora_layer = LoRALayer(
             up=torch.randn(out_features, rank),
             mid=None,
@@ -323,15 +318,11 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
         )
 
         input = torch.randn(1, in_features)
-        return (layer, [(lora_layer, 1.0), (lora_layer_2, 0.5)], input, True)
-    elif layer_type == "linear_concatenated_lora":
-        # Create a linear layer.
-        in_features = 5
-        sub_layer_out_features = [5, 10, 15]
-        layer = torch.nn.Linear(in_features, sum(sub_layer_out_features))
+        return ([(lora_layer, 1.0), (lora_layer_2, 0.5)], input)
+    elif layer_type == "concatenated_lora":
+        sub_layer_out_features = [16, 16, 32]
 
         # Create a ConcatenatedLoRA layer.
-        rank = 4
         sub_layers: list[LoRALayer] = []
         for out_features in sub_layer_out_features:
             down = torch.randn(rank, in_features)
@@ -341,16 +332,10 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
         concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
 
         input = torch.randn(1, in_features)
-        return (layer, [(concatenated_lora_layer, 0.7)], input, True)
-    elif layer_type == "linear_flux_control_lora":
-        # Create a linear layer.
-        orig_in_features = 10
-        out_features = 40
-        layer = torch.nn.Linear(orig_in_features, out_features)
-
+        return ([(concatenated_lora_layer, 0.7)], input)
+    elif layer_type == "flux_control_lora":
         # Create a FluxControlLoRALayer.
-        patched_in_features = 20
-        rank = 4
+        patched_in_features = 40
         lora_layer = FluxControlLoRALayer(
             up=torch.randn(out_features, rank),
             mid=None,
@@ -360,17 +345,17 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
         )
 
         input = torch.randn(1, patched_in_features)
-        return (layer, [(lora_layer, 0.7)], input, True)
+        return ([(lora_layer, 0.7)], input)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
 
 
 @parameterize_all_devices
-def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest):
-    layer, patches, input, supports_cpu_inference = layer_and_patch_under_test
+def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest):
+    patches, input = patch_under_test
 
-    if device == "cpu" and not supports_cpu_inference:
-        pytest.skip("Layer does not support CPU inference.")
+    # Build the base layer under test.
+    layer = torch.nn.Linear(32, 64)
 
     # Move the layer and input to the device.
     layer_to_device_via_state_dict(layer, device)
@@ -397,3 +382,60 @@ def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchU
     output_patched = layer_patched(input)
     output_custom = custom_layer(input)
     assert torch.allclose(output_patched, output_custom, atol=1e-6)
+
+
+@pytest.fixture(
+    params=[
+        "linear_ggml_quantized",
+        "invoke_linear_8_bit_lt",
+        "invoke_linear_nf4",
+    ]
+)
+def quantized_linear_layer_under_test(request: pytest.FixtureRequest):
+    in_features = 32
+    out_features = 64
+    torch.manual_seed(0)
+    layer_type = request.param
+    orig_layer = torch.nn.Linear(in_features, out_features)
+    if layer_type == "linear_ggml_quantized":
+        return orig_layer, build_linear_layer_with_ggml_quantized_tensor(orig_layer)
+    elif layer_type == "invoke_linear_8_bit_lt":
+        return orig_layer, build_linear_8bit_lt_layer(orig_layer)
+    elif layer_type == "invoke_linear_nf4":
+        return orig_layer, build_linear_nf4_layer(orig_layer)
+    else:
+        raise ValueError(f"Unsupported layer_type: {layer_type}")
+
+
+@parameterize_cuda_and_mps
+def test_quantized_linear_sidecar_patches(
+    device: str,
+    quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
+    patch_under_test: PatchUnderTest,
+):
+    """Test that patches can be applied to quantized linear layers and that the output is the same as when the patch is
+    applied to a non-quantized linear layer.
+    """
+    patches, input = patch_under_test
+
+    linear_layer, quantized_linear_layer = quantized_linear_layer_under_test
+
+    # Move everything to the device.
+    layer_to_device_via_state_dict(linear_layer, device)
+    layer_to_device_via_state_dict(quantized_linear_layer, device)
+    input = input.to(torch.device(device))
+
+    # Wrap both layers in custom layers.
+    linear_layer_custom = wrap_single_custom_layer(linear_layer)
+    quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
+
+    # Apply the patches to the custom layers.
+    for patch, weight in patches:
+        patch.to(torch.device(device))
+        linear_layer_custom.add_patch(patch, weight)
+        quantized_linear_layer_custom.add_patch(patch, weight)
+
+    # Run inference with the original layer and the patched layer and assert they are equal.
+    output_linear_patched = linear_layer_custom(input)
+    output_quantized_patched = quantized_linear_layer_custom(input)
+    assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2)
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py
index e23cb25eb0..9a225267fb 100644
--- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py
@@ -14,17 +14,20 @@ else:
 from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
 
 
-def build_linear_8bit_lt_layer():
+def build_linear_8bit_lt_layer(orig_layer: torch.nn.Linear | None = None):
     if not torch.cuda.is_available():
         pytest.skip("CUDA is not available")
 
     torch.manual_seed(1)
 
-    orig_layer = torch.nn.Linear(32, 64)
+    if orig_layer is None:
+        orig_layer = torch.nn.Linear(32, 64)
     orig_layer_state_dict = orig_layer.state_dict()
 
     # Prepare a quantized InvokeLinear8bitLt layer.
-    quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
+    quantized_layer = InvokeLinear8bitLt(
+        input_features=orig_layer.in_features, output_features=orig_layer.out_features, has_fp16_weights=False
+    )
     quantized_layer.load_state_dict(orig_layer_state_dict)
     quantized_layer.to("cuda")
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py
index 17854597ec..3559ddea6c 100644
--- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py
@@ -10,17 +10,19 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
 from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
 
 
-def build_linear_nf4_layer():
+def build_linear_nf4_layer(orig_layer: torch.nn.Linear | None = None):
     if not torch.cuda.is_available():
         pytest.skip("CUDA is not available")
 
     torch.manual_seed(1)
 
-    orig_layer = torch.nn.Linear(64, 16)
+    if orig_layer is None:
+        orig_layer = torch.nn.Linear(64, 16)
+
     orig_layer_state_dict = orig_layer.state_dict()
 
     # Prepare a quantized InvokeLinearNF4 layer.
-    quantized_layer = InvokeLinearNF4(input_features=64, output_features=16)
+    quantized_layer = InvokeLinearNF4(input_features=orig_layer.in_features, output_features=orig_layer.out_features)
     quantized_layer.load_state_dict(orig_layer_state_dict)
     quantized_layer.to("cuda")
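
For reference, the dispatch order that both quantized custom layers now share can be sketched as follows. This is a minimal, self-contained illustration of the pattern in the diff, not InvokeAI's actual classes: the `PatchAwareLinear` name, the plain `torch.nn.Linear` stand-in, and the toy patch application inside `_forward_with_patches` are assumptions made for the example (the real code delegates to `autocast_linear_forward_sidecar_patches` and the bitsandbytes matmul paths).

```python
# Sketch of the forward-dispatch pattern added above, using a plain torch.nn.Linear
# stand-in instead of the bitsandbytes-quantized classes. Names are illustrative only.
import torch


class PatchAwareLinear(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features)
        self._patches_and_weights: list[tuple[torch.nn.Module, float]] = []
        self._device_autocasting_enabled = False

    def _forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
        # Toy stand-in for the sidecar-patch path: run the base layer, then add each
        # patch's contribution scaled by its weight.
        out = super().forward(x)
        for patch, weight in self._patches_and_weights:
            out = out + weight * patch(x)
        return out

    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the device-autocasting path.
        return super().forward(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same dispatch order as the diff: sidecar patches take priority over device
        # autocasting, which takes priority over the unmodified base forward.
        if len(self._patches_and_weights) > 0:
            return self._forward_with_patches(x)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(x)
        else:
            return super().forward(x)


if __name__ == "__main__":
    layer = PatchAwareLinear(32, 64)
    layer._patches_and_weights.append((torch.nn.Linear(32, 64), 0.7))
    print(layer(torch.randn(1, 32)).shape)  # torch.Size([1, 64])
```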