From 1b560208761482cae0c5e1b044d4b2b0cd58dc67 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sat, 21 Dec 2024 22:41:19 +0000
Subject: [PATCH] Add CustomInvokeLinear8bitLt layer for device streaming with
 InvokeLinear8bitLt layers.

---
 .../torch_module_autocast/autocast_modules.py | 25 +++++++
 .../torch_module_autocast.py                  |  3 +
 .../test_autocast_modules.py                  | 67 +++++++++++++++++++
 .../test_torch_module_autocast.py             | 31 +++++++++
 4 files changed, 126 insertions(+)
 create mode 100644 tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py

diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py
index 03849c5b0e..31b5be060f 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py
@@ -1,7 +1,10 @@
 from typing import TypeVar

+import bitsandbytes as bnb
 import torch

+from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
+
 T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)

 # This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
@@ -59,3 +62,25 @@ class CustomEmbedding(torch.nn.Embedding):
             self.scale_grad_by_freq,
             self.sparse,
         )
+
+
+class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        matmul_state = bnb.MatmulLtState()
+        matmul_state.threshold = self.state.threshold
+        matmul_state.has_fp16_weights = self.state.has_fp16_weights
+        matmul_state.use_pool = self.state.use_pool
+        matmul_state.is_training = self.training
+        # The underlying InvokeInt8Params weight must already be quantized.
+        assert self.weight.CB is not None
+        matmul_state.CB = cast_to_device(self.weight.CB, x.device)
+        matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)
+
+        # The weights are cast automatically as Int8Params, but the bias has to be cast manually.
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
+        # its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
+        # on the wrong device.
+        return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
index 625f1943a5..59c99ab411 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
@@ -5,8 +5,10 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autoc
     CustomConv2d,
     CustomEmbedding,
     CustomGroupNorm,
+    CustomInvokeLinear8bitLt,
     CustomLinear,
 )
+from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt

 AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
     torch.nn.Linear: CustomLinear,
@@ -14,6 +16,7 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
     torch.nn.Conv2d: CustomConv2d,
     torch.nn.GroupNorm: CustomGroupNorm,
     torch.nn.Embedding: CustomEmbedding,
+    InvokeLinear8bitLt: CustomInvokeLinear8bitLt,
 }


diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py
new file mode 100644
index 0000000000..06887f6968
--- /dev/null
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py
@@ -0,0 +1,67 @@
+import pytest
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
+    CustomInvokeLinear8bitLt,
+)
+from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
+
+
+@pytest.fixture
+def linear_8bit_lt_layer():
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
+    torch.manual_seed(1)
+
+    orig_layer = torch.nn.Linear(32, 64)
+    orig_layer_state_dict = orig_layer.state_dict()
+
+    # Prepare a quantized InvokeLinear8bitLt layer.
+    quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
+    quantized_layer.load_state_dict(orig_layer_state_dict)
+    quantized_layer.to("cuda")
+
+    # Assert that the InvokeLinear8bitLt layer is quantized.
+    assert quantized_layer.weight.CB is not None
+    assert quantized_layer.weight.SCB is not None
+    assert quantized_layer.weight.CB.dtype == torch.int8
+
+    return quantized_layer
+
+
+def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
+    """Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
+    # Run inference on the original layer.
+    x = torch.randn(10, 32).to("cuda")
+    y_quantized = linear_8bit_lt_layer(x)
+
+    # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
+    linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
+    y_custom = linear_8bit_lt_layer(x)
+
+    # Assert that the quantized and custom layers produce the same output.
+    assert torch.allclose(y_quantized, y_custom, atol=1e-5)
+
+
+def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
+    """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
+    # Run inference on the original layer.
+    x = torch.randn(10, 32).to("cuda")
+    y_quantized = linear_8bit_lt_layer(x)
+
+    # Copy the state dict to the CPU and reload it.
+    state_dict = linear_8bit_lt_layer.state_dict()
+    state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
+    linear_8bit_lt_layer.load_state_dict(state_dict)
+
+    # Inference of the original layer should fail.
+    with pytest.raises(RuntimeError):
+        linear_8bit_lt_layer(x)
+
+    # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
+    linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
+    y_custom = linear_8bit_lt_layer(x)
+
+    # Assert that the quantized and custom layers produce the same output.
+    assert torch.allclose(y_quantized, y_custom, atol=1e-5)
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py
index cef1cdf9a4..1c47d297cb 100644
--- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py
@@ -6,6 +6,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     apply_custom_layers_to_model,
     remove_custom_layers_from_model,
 )
+from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8
 from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor

 cuda_and_mps = pytest.mark.parametrize(
@@ -81,3 +82,33 @@ def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.n
     # The results from all inference runs should be the same.
     assert torch.allclose(autocast_result.to("cpu"), expected, atol=1e-5)
     assert torch.allclose(after_result, expected, atol=1e-5)
+
+
+def test_torch_module_autocast_bnb_llm_int8_linear_layer():
+    if not torch.cuda.is_available():
+        pytest.skip("requires CUDA device")
+
+    model = ModelWithLinearLayer()
+    model = quantize_model_llm_int8(model, modules_to_not_convert=set())
+    # The act of moving the model to the CUDA device will trigger quantization.
+    model.to("cuda")
+    # Confirm that the layer is quantized.
+    assert isinstance(model.linear, InvokeLinear8bitLt)
+    assert model.linear.weight.CB is not None
+    assert model.linear.weight.SCB is not None
+
+    # Run inference on the GPU.
+    x = torch.randn(10, 32)
+    expected = model(x.to("cuda"))
+    assert expected.device.type == "cuda"
+
+    # Move the model back to the CPU and add the custom layers to the model.
+    model.to("cpu")
+    apply_custom_layers_to_model(model)
+
+    # Run inference with weights being streamed to the GPU.
+    autocast_result = model(x.to("cuda"))
+    assert autocast_result.device.type == "cuda"
+
+    # The results from all inference runs should be the same.
+    assert torch.allclose(autocast_result, expected, atol=1e-5)
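
Below is a minimal usage sketch (not part of the patch) of the streaming path these changes enable, mirroring the new test_torch_module_autocast_bnb_llm_int8_linear_layer test. ModelWithLinearLayer is a hypothetical stand-in module defined only for illustration; quantize_model_llm_int8 and apply_custom_layers_to_model are the existing InvokeAI helpers used in the tests above.

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8


class ModelWithLinearLayer(torch.nn.Module):
    # Hypothetical stand-in model with a single Linear layer.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = ModelWithLinearLayer()
# Quantize to LLM.int8(); moving the model to CUDA triggers the actual quantization.
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
model.to("cuda")

# Keep the quantized weights on the CPU and swap in the custom layers. Per
# AUTOCAST_MODULE_TYPE_MAPPING, InvokeLinear8bitLt becomes CustomInvokeLinear8bitLt.
model.to("cpu")
apply_custom_layers_to_model(model)

# The input lives on the GPU; the CB/SCB buffers are streamed to x.device on each forward pass.
y = model(torch.randn(10, 32).to("cuda"))
assert y.device.type == "cuda"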