Get custom layer patches working with all quantized linear layer types.

This commit is contained in:
Ryan Dick 2024-12-27 22:00:22 +00:00
parent ef970a1cdc
commit f2981979f9
5 changed files with 121 additions and 58 deletions

View File

@ -2,6 +2,9 @@ import bitsandbytes as bnb
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
autocast_linear_forward_sidecar_patches,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
@ -9,6 +12,9 @@ from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
matmul_state = bnb.MatmulLtState()
matmul_state.threshold = self.state.threshold
@ -30,7 +36,9 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._device_autocasting_enabled:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)

View File

@ -4,6 +4,9 @@ import bitsandbytes as bnb
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
autocast_linear_forward_sidecar_patches,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
@ -11,6 +14,9 @@ from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)
@ -48,7 +54,9 @@ class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._device_autocasting_enabled:
if len(self._patches_and_weights) > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)

View File

@ -25,13 +25,15 @@ from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_m
from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
def build_linear_layer_with_ggml_quantized_tensor():
layer = torch.nn.Linear(32, 64)
ggml_quantized_weight = quantize_tensor(layer.weight, gguf.GGMLQuantizationType.Q8_0)
layer.weight = torch.nn.Parameter(ggml_quantized_weight)
ggml_quantized_bias = quantize_tensor(layer.bias, gguf.GGMLQuantizationType.Q8_0)
layer.bias = torch.nn.Parameter(ggml_quantized_bias)
return layer
def build_linear_layer_with_ggml_quantized_tensor(orig_layer: torch.nn.Linear | None = None):
if orig_layer is None:
orig_layer = torch.nn.Linear(32, 64)
ggml_quantized_weight = quantize_tensor(orig_layer.weight, gguf.GGMLQuantizationType.Q8_0)
orig_layer.weight = torch.nn.Parameter(ggml_quantized_weight)
ggml_quantized_bias = quantize_tensor(orig_layer.bias, gguf.GGMLQuantizationType.Q8_0)
orig_layer.bias = torch.nn.Parameter(ggml_quantized_bias)
return orig_layer
parameterize_all_devices = pytest.mark.parametrize(
@ -267,30 +269,29 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
assert torch.allclose(orig_output, custom_output)
LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float]], torch.Tensor, bool]
PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor]
@pytest.fixture(
params=[
"linear_single_lora",
"linear_multiple_loras",
"linear_concatenated_lora",
"linear_flux_control_lora",
"single_lora",
"multiple_loras",
"concatenated_lora",
"flux_control_lora",
]
)
def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
"""A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
"""A fixture that returns a tuple of (patches, input) for the patch under test."""
layer_type = request.param
torch.manual_seed(0)
if layer_type == "linear_single_lora":
# Create a linear layer.
in_features = 10
out_features = 20
layer = torch.nn.Linear(in_features, out_features)
# The assumed in/out features of the base linear layer.
in_features = 32
out_features = 64
# Create a LoRA layer.
rank = 4
rank = 4
if layer_type == "single_lora":
lora_layer = LoRALayer(
up=torch.randn(out_features, rank),
mid=None,
@ -299,14 +300,8 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
bias=torch.randn(out_features),
)
input = torch.randn(1, in_features)
return (layer, [(lora_layer, 0.7)], input, True)
elif layer_type == "linear_multiple_loras":
# Create a linear layer.
rank = 4
in_features = 10
out_features = 20
layer = torch.nn.Linear(in_features, out_features)
return ([(lora_layer, 0.7)], input)
elif layer_type == "multiple_loras":
lora_layer = LoRALayer(
up=torch.randn(out_features, rank),
mid=None,
@ -323,15 +318,11 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
)
input = torch.randn(1, in_features)
return (layer, [(lora_layer, 1.0), (lora_layer_2, 0.5)], input, True)
elif layer_type == "linear_concatenated_lora":
# Create a linear layer.
in_features = 5
sub_layer_out_features = [5, 10, 15]
layer = torch.nn.Linear(in_features, sum(sub_layer_out_features))
return ([(lora_layer, 1.0), (lora_layer_2, 0.5)], input)
elif layer_type == "concatenated_lora":
sub_layer_out_features = [16, 16, 32]
# Create a ConcatenatedLoRA layer.
rank = 4
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
@ -341,16 +332,10 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
input = torch.randn(1, in_features)
return (layer, [(concatenated_lora_layer, 0.7)], input, True)
elif layer_type == "linear_flux_control_lora":
# Create a linear layer.
orig_in_features = 10
out_features = 40
layer = torch.nn.Linear(orig_in_features, out_features)
return ([(concatenated_lora_layer, 0.7)], input)
elif layer_type == "flux_control_lora":
# Create a FluxControlLoRALayer.
patched_in_features = 20
rank = 4
patched_in_features = 40
lora_layer = FluxControlLoRALayer(
up=torch.randn(out_features, rank),
mid=None,
@ -360,17 +345,17 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
)
input = torch.randn(1, patched_in_features)
return (layer, [(lora_layer, 0.7)], input, True)
return ([(lora_layer, 0.7)], input)
else:
raise ValueError(f"Unsupported layer_type: {layer_type}")
@parameterize_all_devices
def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest):
layer, patches, input, supports_cpu_inference = layer_and_patch_under_test
def test_linear_sidecar_patches(device: str, layer_type: str, patch_under_test: PatchUnderTest):
patches, input = patch_under_test
if device == "cpu" and not supports_cpu_inference:
pytest.skip("Layer does not support CPU inference.")
# Build the base layer under test.
layer = torch.nn.Linear(32, 64)
# Move the layer and input to the device.
layer_to_device_via_state_dict(layer, device)
@ -397,3 +382,60 @@ def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchU
output_patched = layer_patched(input)
output_custom = custom_layer(input)
assert torch.allclose(output_patched, output_custom, atol=1e-6)
@pytest.fixture(
params=[
"linear_ggml_quantized",
"invoke_linear_8_bit_lt",
"invoke_linear_nf4",
]
)
def quantized_linear_layer_under_test(request: pytest.FixtureRequest):
in_features = 32
out_features = 64
torch.manual_seed(0)
layer_type = request.param
orig_layer = torch.nn.Linear(in_features, out_features)
if layer_type == "linear_ggml_quantized":
return orig_layer, build_linear_layer_with_ggml_quantized_tensor(orig_layer)
elif layer_type == "invoke_linear_8_bit_lt":
return orig_layer, build_linear_8bit_lt_layer(orig_layer)
elif layer_type == "invoke_linear_nf4":
return orig_layer, build_linear_nf4_layer(orig_layer)
else:
raise ValueError(f"Unsupported layer_type: {layer_type}")
@parameterize_cuda_and_mps
def test_quantized_linear_sidecar_patches(
device: str,
quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
patch_under_test: PatchUnderTest,
):
"""Test that patches can be applied to quantized linear layers and that the output is the same as when the patch is
applied to a non-quantized linear layer.
"""
patches, input = patch_under_test
linear_layer, quantized_linear_layer = quantized_linear_layer_under_test
# Move everything to the device.
layer_to_device_via_state_dict(linear_layer, device)
layer_to_device_via_state_dict(quantized_linear_layer, device)
input = input.to(torch.device(device))
# Wrap both layers in custom layers.
linear_layer_custom = wrap_single_custom_layer(linear_layer)
quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
# Apply the patches to the custom layers.
for patch, weight in patches:
patch.to(torch.device(device))
linear_layer_custom.add_patch(patch, weight)
quantized_linear_layer_custom.add_patch(patch, weight)
# Run inference with the original layer and the patched layer and assert they are equal.
output_linear_patched = linear_layer_custom(input)
output_quantized_patched = quantized_linear_layer_custom(input)
assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2)

View File

@ -14,17 +14,20 @@ else:
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
def build_linear_8bit_lt_layer():
def build_linear_8bit_lt_layer(orig_layer: torch.nn.Linear | None = None):
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
torch.manual_seed(1)
orig_layer = torch.nn.Linear(32, 64)
if orig_layer is None:
orig_layer = torch.nn.Linear(32, 64)
orig_layer_state_dict = orig_layer.state_dict()
# Prepare a quantized InvokeLinear8bitLt layer.
quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
quantized_layer = InvokeLinear8bitLt(
input_features=orig_layer.in_features, output_features=orig_layer.out_features, has_fp16_weights=False
)
quantized_layer.load_state_dict(orig_layer_state_dict)
quantized_layer.to("cuda")

View File

@ -10,17 +10,19 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
def build_linear_nf4_layer():
def build_linear_nf4_layer(orig_layer: torch.nn.Linear | None = None):
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
torch.manual_seed(1)
orig_layer = torch.nn.Linear(64, 16)
if orig_layer is None:
orig_layer = torch.nn.Linear(64, 16)
orig_layer_state_dict = orig_layer.state_dict()
# Prepare a quantized InvokeLinearNF4 layer.
quantized_layer = InvokeLinearNF4(input_features=64, output_features=16)
quantized_layer = InvokeLinearNF4(input_features=orig_layer.in_features, output_features=orig_layer.out_features)
quantized_layer.load_state_dict(orig_layer_state_dict)
quantized_layer.to("cuda")