mirror of https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00

commit f2981979f9 (parent ef970a1cdc)
Get custom layer patches working with all quantized linear layer types.
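The changes below give the two bitsandbytes-based custom linear layers (CustomInvokeLinear8bitLt and CustomInvokeLinearNF4) a sidecar-patch path: whenever patches are registered on the module, forward() now routes through the shared autocast_linear_forward_sidecar_patches helper; otherwise it falls back to the existing autocast path or the plain base forward. The sketch below is illustrative only and not part of the commit: it uses a plain torch.nn.Linear stand-in instead of a quantized layer, and the additive patch composition inside _autocast_forward_with_patches is an assumption, since the helper's implementation is not shown in this diff.

# Illustrative sketch, not from the commit: the dispatch order that the new forward()
# methods implement. The base class is a plain nn.Linear stand-in rather than a
# quantized layer, and the additive patch composition is an assumption.
import torch


class PatchAwareLinearSketch(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features)
        self._patches_and_weights: list[tuple[torch.nn.Module, float]] = []
        self._device_autocasting_enabled = False

    def add_patch(self, patch: torch.nn.Module, weight: float) -> None:
        self._patches_and_weights.append((patch, weight))

    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for autocast_linear_forward_sidecar_patches(self, x, ...).
        out = super().forward(x)
        for patch, weight in self._patches_and_weights:
            out = out + weight * patch(x)
        return out

    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same priority order as the diff: sidecar patches first, then the autocast
        # path, then the unmodified base forward.
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(x)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(x)
        else:
            return super().forward(x)


layer = PatchAwareLinearSketch(32, 64)
layer.add_patch(torch.nn.Linear(32, 64), 0.7)
print(layer(torch.randn(1, 32)).shape)  # torch.Size([1, 64])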
@@ -2,6 +2,9 @@ import bitsandbytes as bnb
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
+    autocast_linear_forward_sidecar_patches,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
@@ -9,6 +12,9 @@ from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
 
 
 class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
+        return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
+
     def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         matmul_state = bnb.MatmulLtState()
         matmul_state.threshold = self.state.threshold
@@ -30,7 +36,9 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
         return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self._device_autocasting_enabled:
+        if len(self._patches_and_weights) > 0:
+            return self._autocast_forward_with_patches(x)
+        elif self._device_autocasting_enabled:
             return self._autocast_forward(x)
         else:
             return super().forward(x)
@@ -4,6 +4,9 @@ import bitsandbytes as bnb
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
+    autocast_linear_forward_sidecar_patches,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
@@ -11,6 +14,9 @@ from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
 
 
 class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
+        return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
+
     def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)
 
@@ -48,7 +54,9 @@ class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
         return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self._device_autocasting_enabled:
+        if len(self._patches_and_weights) > 0:
+            return self._autocast_forward_with_patches(x)
+        elif self._device_autocasting_enabled:
             return self._autocast_forward(x)
         else:
             return super().forward(x)
@@ -25,13 +25,15 @@ from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_m
 from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
 
 
-def build_linear_layer_with_ggml_quantized_tensor():
-    layer = torch.nn.Linear(32, 64)
-    ggml_quantized_weight = quantize_tensor(layer.weight, gguf.GGMLQuantizationType.Q8_0)
-    layer.weight = torch.nn.Parameter(ggml_quantized_weight)
-    ggml_quantized_bias = quantize_tensor(layer.bias, gguf.GGMLQuantizationType.Q8_0)
-    layer.bias = torch.nn.Parameter(ggml_quantized_bias)
-    return layer
+def build_linear_layer_with_ggml_quantized_tensor(orig_layer: torch.nn.Linear | None = None):
+    if orig_layer is None:
+        orig_layer = torch.nn.Linear(32, 64)
+
+    ggml_quantized_weight = quantize_tensor(orig_layer.weight, gguf.GGMLQuantizationType.Q8_0)
+    orig_layer.weight = torch.nn.Parameter(ggml_quantized_weight)
+    ggml_quantized_bias = quantize_tensor(orig_layer.bias, gguf.GGMLQuantizationType.Q8_0)
+    orig_layer.bias = torch.nn.Parameter(ggml_quantized_bias)
+    return orig_layer
 
 
 parameterize_all_devices = pytest.mark.parametrize(
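A hypothetical usage sketch of the refactored helper, as it would be called from inside this test module (the helper and the gguf/quantize_tensor imports are defined above): calling it with no argument keeps the old behaviour, while passing an existing layer replaces that layer's weight and bias with GGML-quantized tensors and returns the same object.

import torch

# Old-style call: the helper creates and quantizes its own 32x64 layer.
layer = build_linear_layer_with_ggml_quantized_tensor()

# New in this commit: quantize a caller-supplied layer. Note that the passed layer is
# modified in place (its weight and bias parameters are replaced) and then returned.
orig = torch.nn.Linear(32, 64)
quantized = build_linear_layer_with_ggml_quantized_tensor(orig)
assert quantized is orig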
@@ -267,30 +269,29 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
     assert torch.allclose(orig_output, custom_output)
 
 
-LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float]], torch.Tensor, bool]
+PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor]
 
 
 @pytest.fixture(
     params=[
-        "linear_single_lora",
-        "linear_multiple_loras",
-        "linear_concatenated_lora",
-        "linear_flux_control_lora",
+        "single_lora",
+        "multiple_loras",
+        "concatenated_lora",
+        "flux_control_lora",
     ]
 )
-def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
-    """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
+def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
+    """A fixture that returns a tuple of (patches, input) for the patch under test."""
     layer_type = request.param
     torch.manual_seed(0)
 
-    if layer_type == "linear_single_lora":
-        # Create a linear layer.
-        in_features = 10
-        out_features = 20
-        layer = torch.nn.Linear(in_features, out_features)
+    # The assumed in/out features of the base linear layer.
+    in_features = 32
+    out_features = 64
 
-        # Create a LoRA layer.
-        rank = 4
+    rank = 4
+
+    if layer_type == "single_lora":
         lora_layer = LoRALayer(
             up=torch.randn(out_features, rank),
             mid=None,
@@ -299,14 +300,8 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
             bias=torch.randn(out_features),
         )
         input = torch.randn(1, in_features)
-        return (layer, [(lora_layer, 0.7)], input, True)
-    elif layer_type == "linear_multiple_loras":
-        # Create a linear layer.
-        rank = 4
-        in_features = 10
-        out_features = 20
-        layer = torch.nn.Linear(in_features, out_features)
-
+        return ([(lora_layer, 0.7)], input)
+    elif layer_type == "multiple_loras":
         lora_layer = LoRALayer(
             up=torch.randn(out_features, rank),
             mid=None,
@@ -323,15 +318,11 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
         )
 
         input = torch.randn(1, in_features)
-        return (layer, [(lora_layer, 1.0), (lora_layer_2, 0.5)], input, True)
-    elif layer_type == "linear_concatenated_lora":
-        # Create a linear layer.
-        in_features = 5
-        sub_layer_out_features = [5, 10, 15]
-        layer = torch.nn.Linear(in_features, sum(sub_layer_out_features))
+        return ([(lora_layer, 1.0), (lora_layer_2, 0.5)], input)
+    elif layer_type == "concatenated_lora":
+        sub_layer_out_features = [16, 16, 32]
 
         # Create a ConcatenatedLoRA layer.
-        rank = 4
         sub_layers: list[LoRALayer] = []
         for out_features in sub_layer_out_features:
             down = torch.randn(rank, in_features)
@@ -341,16 +332,10 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
         concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
 
         input = torch.randn(1, in_features)
-        return (layer, [(concatenated_lora_layer, 0.7)], input, True)
-    elif layer_type == "linear_flux_control_lora":
-        # Create a linear layer.
-        orig_in_features = 10
-        out_features = 40
-        layer = torch.nn.Linear(orig_in_features, out_features)
-
+        return ([(concatenated_lora_layer, 0.7)], input)
+    elif layer_type == "flux_control_lora":
         # Create a FluxControlLoRALayer.
-        patched_in_features = 20
-        rank = 4
+        patched_in_features = 40
         lora_layer = FluxControlLoRALayer(
             up=torch.randn(out_features, rank),
             mid=None,
@@ -360,17 +345,17 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
         )
 
         input = torch.randn(1, patched_in_features)
-        return (layer, [(lora_layer, 0.7)], input, True)
+        return ([(lora_layer, 0.7)], input)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
 
 
 @parameterize_all_devices
-def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest):
-    layer, patches, input, supports_cpu_inference = layer_and_patch_under_test
+def test_linear_sidecar_patches(device: str, layer_type: str, patch_under_test: PatchUnderTest):
+    patches, input = patch_under_test
 
-    if device == "cpu" and not supports_cpu_inference:
-        pytest.skip("Layer does not support CPU inference.")
+    # Build the base layer under test.
+    layer = torch.nn.Linear(32, 64)
 
     # Move the layer and input to the device.
     layer_to_device_via_state_dict(layer, device)
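For orientation, the LoRALayer tensors built by this fixture follow the usual low-rank pattern: an up matrix of shape (out_features, rank), a down matrix of shape (rank, in_features), and an optional bias. The sketch below shows a generic way such a patch combines with a base linear layer; it is only an illustration and ignores InvokeAI's exact alpha/rank scaling and the sidecar implementation exercised by the tests.

import torch

torch.manual_seed(0)
in_features, out_features, rank = 32, 64, 4
patch_weight = 0.7

base = torch.nn.Linear(in_features, out_features)
up = torch.randn(out_features, rank)    # (out_features, rank)
down = torch.randn(rank, in_features)   # (rank, in_features)
bias = torch.randn(out_features)

x = torch.randn(1, in_features)
# Base output plus a weighted low-rank update: y = base(x) + w * (x @ down^T @ up^T + bias)
y = base(x) + patch_weight * (x @ down.t() @ up.t() + bias)
print(y.shape)  # torch.Size([1, 64])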
@@ -397,3 +382,60 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
     output_patched = layer_patched(input)
     output_custom = custom_layer(input)
     assert torch.allclose(output_patched, output_custom, atol=1e-6)
+
+
+@pytest.fixture(
+    params=[
+        "linear_ggml_quantized",
+        "invoke_linear_8_bit_lt",
+        "invoke_linear_nf4",
+    ]
+)
+def quantized_linear_layer_under_test(request: pytest.FixtureRequest):
+    in_features = 32
+    out_features = 64
+    torch.manual_seed(0)
+    layer_type = request.param
+    orig_layer = torch.nn.Linear(in_features, out_features)
+    if layer_type == "linear_ggml_quantized":
+        return orig_layer, build_linear_layer_with_ggml_quantized_tensor(orig_layer)
+    elif layer_type == "invoke_linear_8_bit_lt":
+        return orig_layer, build_linear_8bit_lt_layer(orig_layer)
+    elif layer_type == "invoke_linear_nf4":
+        return orig_layer, build_linear_nf4_layer(orig_layer)
+    else:
+        raise ValueError(f"Unsupported layer_type: {layer_type}")
+
+
+@parameterize_cuda_and_mps
+def test_quantized_linear_sidecar_patches(
+    device: str,
+    quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
+    patch_under_test: PatchUnderTest,
+):
+    """Test that patches can be applied to quantized linear layers and that the output is the same as when the patch is
+    applied to a non-quantized linear layer.
+    """
+    patches, input = patch_under_test
+
+    linear_layer, quantized_linear_layer = quantized_linear_layer_under_test
+
+    # Move everything to the device.
+    layer_to_device_via_state_dict(linear_layer, device)
+    layer_to_device_via_state_dict(quantized_linear_layer, device)
+    input = input.to(torch.device(device))
+
+    # Wrap both layers in custom layers.
+    linear_layer_custom = wrap_single_custom_layer(linear_layer)
+    quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
+
+    # Apply the patches to the custom layers.
+    for patch, weight in patches:
+        patch.to(torch.device(device))
+        linear_layer_custom.add_patch(patch, weight)
+        quantized_linear_layer_custom.add_patch(patch, weight)
+
+    # Run inference with the original layer and the patched layer and assert they are equal.
+    output_linear_patched = linear_layer_custom(input)
+    output_quantized_patched = quantized_linear_layer_custom(input)
+    assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2)
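Because test_quantized_linear_sidecar_patches requests both parametrized fixtures, pytest collects it once per combination of quantized layer type and patch type (3 x 4 = 12 cases per device here). A minimal, standalone illustration of that fixture cross-product behaviour:

import pytest


@pytest.fixture(params=["a", "b"])
def first(request: pytest.FixtureRequest) -> str:
    return request.param


@pytest.fixture(params=[1, 2, 3])
def second(request: pytest.FixtureRequest) -> int:
    return request.param


def test_cross_product(first: str, second: int):
    # Collected 2 x 3 = 6 times: (a, 1), (a, 2), (a, 3), (b, 1), (b, 2), (b, 3).
    assert first in {"a", "b"} and second in {1, 2, 3}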
@@ -14,17 +14,20 @@ else:
     from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
 
 
-def build_linear_8bit_lt_layer():
+def build_linear_8bit_lt_layer(orig_layer: torch.nn.Linear | None = None):
     if not torch.cuda.is_available():
         pytest.skip("CUDA is not available")
 
     torch.manual_seed(1)
 
-    orig_layer = torch.nn.Linear(32, 64)
+    if orig_layer is None:
+        orig_layer = torch.nn.Linear(32, 64)
     orig_layer_state_dict = orig_layer.state_dict()
 
     # Prepare a quantized InvokeLinear8bitLt layer.
-    quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
+    quantized_layer = InvokeLinear8bitLt(
+        input_features=orig_layer.in_features, output_features=orig_layer.out_features, has_fp16_weights=False
+    )
     quantized_layer.load_state_dict(orig_layer_state_dict)
     quantized_layer.to("cuda")
 
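A hypothetical usage sketch of the refactored helper from inside its test module (it requires CUDA, matching the pytest.skip guard, and the NF4 helper below follows the same pattern): passing a pre-built layer lets a test keep a non-quantized reference whose weights were used to initialize the quantized copy, while calling the helper with no argument preserves the old behaviour.

import torch

torch.manual_seed(0)
reference = torch.nn.Linear(32, 64)

# The helper loads reference.state_dict() into an InvokeLinear8bitLt with matching
# in/out features and moves it to "cuda"; `reference` itself is left unmodified.
quantized = build_linear_8bit_lt_layer(reference)

# Old-style call, still supported: the helper builds its own 32x64 layer internally.
quantized_default = build_linear_8bit_lt_layer()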
@@ -10,17 +10,19 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
 from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
 
 
-def build_linear_nf4_layer():
+def build_linear_nf4_layer(orig_layer: torch.nn.Linear | None = None):
     if not torch.cuda.is_available():
         pytest.skip("CUDA is not available")
 
     torch.manual_seed(1)
 
-    orig_layer = torch.nn.Linear(64, 16)
+    if orig_layer is None:
+        orig_layer = torch.nn.Linear(64, 16)
+
     orig_layer_state_dict = orig_layer.state_dict()
 
     # Prepare a quantized InvokeLinearNF4 layer.
-    quantized_layer = InvokeLinearNF4(input_features=64, output_features=16)
+    quantized_layer = InvokeLinearNF4(input_features=orig_layer.in_features, output_features=orig_layer.out_features)
     quantized_layer.load_state_dict(orig_layer_state_dict)
     quantized_layer.to("cuda")
 