Add CustomInvokeLinear8bitLt layer for device streaming with InvokeLinear8bitLt layers.
Parent: 3f990393a1
Commit: 1b56020876
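
For context, the streaming path this commit enables can be condensed into a short usage sketch. It is assembled from the new unit tests in this diff (not an official API example), requires a CUDA device and bitsandbytes, and the 32/64 layer dimensions simply follow the test fixture:

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
    CustomInvokeLinear8bitLt,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt

# Build a quantized layer. Moving it to CUDA triggers quantization; afterwards the
# quantized weights are moved back to the CPU to simulate an offloaded model.
layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
layer.load_state_dict(torch.nn.Linear(32, 64).state_dict())
layer.to("cuda")
layer.load_state_dict({k: v.to("cpu") for k, v in layer.state_dict().items()})

# Swap in the custom class so that forward() streams the quantized weights (CB/SCB)
# and the bias to the device of the input activations on each call.
layer.__class__ = CustomInvokeLinear8bitLt
y = layer(torch.randn(10, 32, device="cuda"))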
@@ -1,7 +1,10 @@
from typing import TypeVar

import bitsandbytes as bnb
import torch

from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)

# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
@@ -59,3 +62,25 @@ class CustomEmbedding(torch.nn.Embedding):
            self.scale_grad_by_freq,
            self.sparse,
        )


class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        matmul_state = bnb.MatmulLtState()
        matmul_state.threshold = self.state.threshold
        matmul_state.has_fp16_weights = self.state.has_fp16_weights
        matmul_state.use_pool = self.state.use_pool
        matmul_state.is_training = self.training
        # The underlying InvokeInt8Params weight must already be quantized.
        assert self.weight.CB is not None
        matmul_state.CB = cast_to_device(self.weight.CB, x.device)
        matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)

        # The weights are cast automatically as Int8Params, but the bias has to be cast manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration,
        # but its dtype field must be accessible, even though it's not used. We pass in self.weight even though it
        # could be on the wrong device.
        return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
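
The forward() above calls a cast_to_device helper that is not shown in this hunk (presumably defined elsewhere in the same file, alongside the T TypeVar). Its actual implementation may differ; a minimal sketch consistent with how it is used here would be:

from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)


def cast_to_device(t: T, to_device: torch.device) -> T:
    # Hypothetical sketch: pass None through unchanged, otherwise copy the tensor to the target device.
    if t is None:
        return None
    return t.to(to_device)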
@@ -5,8 +5,10 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autoc
    CustomConv2d,
    CustomEmbedding,
    CustomGroupNorm,
    CustomInvokeLinear8bitLt,
    CustomLinear,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt

AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
    torch.nn.Linear: CustomLinear,
@@ -14,6 +16,7 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
    torch.nn.Conv2d: CustomConv2d,
    torch.nn.GroupNorm: CustomGroupNorm,
    torch.nn.Embedding: CustomEmbedding,
    InvokeLinear8bitLt: CustomInvokeLinear8bitLt,
}
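
The tests later in this diff apply the custom classes by swapping a module's __class__ in place, and they call an existing apply_custom_layers_to_model() helper whose implementation is not part of this diff. A plausible sketch of how such a helper could consume AUTOCAST_MODULE_TYPE_MAPPING (an assumption, not the actual implementation) is:

import torch


def apply_custom_layers_to_model_sketch(model: torch.nn.Module) -> None:
    # AUTOCAST_MODULE_TYPE_MAPPING is the dict defined in the hunk above.
    for module in model.modules():
        custom_cls = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module))
        if custom_cls is not None:
            # Swap the class in place, as the new unit tests do for CustomInvokeLinear8bitLt.
            module.__class__ = custom_cls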
@@ -0,0 +1,67 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
    CustomInvokeLinear8bitLt,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt


@pytest.fixture
def linear_8bit_lt_layer():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    torch.manual_seed(1)

    orig_layer = torch.nn.Linear(32, 64)
    orig_layer_state_dict = orig_layer.state_dict()

    # Prepare a quantized InvokeLinear8bitLt layer.
    quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
    quantized_layer.load_state_dict(orig_layer_state_dict)
    quantized_layer.to("cuda")

    # Assert that the InvokeLinear8bitLt layer is quantized.
    assert quantized_layer.weight.CB is not None
    assert quantized_layer.weight.SCB is not None
    assert quantized_layer.weight.CB.dtype == torch.int8

    return quantized_layer


def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
    """Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
    # Run inference on the original layer.
    x = torch.randn(10, 32).to("cuda")
    y_quantized = linear_8bit_lt_layer(x)

    # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
    linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
    y_custom = linear_8bit_lt_layer(x)

    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)


def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
    """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
    # Run inference on the original layer.
    x = torch.randn(10, 32).to("cuda")
    y_quantized = linear_8bit_lt_layer(x)

    # Copy the state dict to the CPU and reload it.
    state_dict = linear_8bit_lt_layer.state_dict()
    state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
    linear_8bit_lt_layer.load_state_dict(state_dict)

    # Inference on the original layer should fail.
    with pytest.raises(RuntimeError):
        linear_8bit_lt_layer(x)

    # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
    linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
    y_custom = linear_8bit_lt_layer(x)

    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)
@@ -6,6 +6,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
    apply_custom_layers_to_model,
    remove_custom_layers_from_model,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8
from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor

cuda_and_mps = pytest.mark.parametrize(
@@ -81,3 +82,33 @@ def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.n
    # The results from all inference runs should be the same.
    assert torch.allclose(autocast_result.to("cpu"), expected, atol=1e-5)
    assert torch.allclose(after_result, expected, atol=1e-5)


def test_torch_module_autocast_bnb_llm_int8_linear_layer():
    if not torch.cuda.is_available():
        pytest.skip("requires CUDA device")

    model = ModelWithLinearLayer()
    model = quantize_model_llm_int8(model, modules_to_not_convert=set())
    # The act of moving the model to the CUDA device will trigger quantization.
    model.to("cuda")
    # Confirm that the layer is quantized.
    assert isinstance(model.linear, InvokeLinear8bitLt)
    assert model.linear.weight.CB is not None
    assert model.linear.weight.SCB is not None

    # Run inference on the GPU.
    x = torch.randn(10, 32)
    expected = model(x.to("cuda"))
    assert expected.device.type == "cuda"

    # Move the model back to the CPU and add the custom layers to the model.
    model.to("cpu")
    apply_custom_layers_to_model(model)

    # Run inference with weights being streamed to the GPU.
    autocast_result = model(x.to("cuda"))
    assert autocast_result.device.type == "cuda"

    # The results from all inference runs should be the same.
    assert torch.allclose(autocast_result, expected, atol=1e-5)
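
The test above uses a ModelWithLinearLayer helper that already exists in this test module and is not part of the diff. Given the torch.randn(10, 32) inputs and the model.linear attribute checked above, a stand-in would look roughly like this (the 64-dim output is an assumption):

import torch


class ModelWithLinearLayer(torch.nn.Module):
    # Hypothetical stand-in; the real helper is defined elsewhere in the test module.
    def __init__(self):
        super().__init__()
        # 32 input features match the torch.randn(10, 32) inputs; the 64 output features are an assumption.
        self.linear = torch.nn.Linear(32, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)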