invoke-ai/InvokeAI (https://github.com/invoke-ai/InvokeAI.git)

commit ef970a1cdc (parent 5ee7405f97)

Add support for FluxControlLoRALayer in CustomLinear layers and add a unit test for it.
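The core of the change: a FluxControlLoRA widens a linear layer's expected input, so the sidecar forward path slices the input back down to the base layer's in_features while the LoRA residual consumes the full input. A minimal plain-torch sketch of that slicing, using the shapes from the new unit test (10 original features, 20 patched features, 40 outputs); this is an illustration, not the InvokeAI API:

```python
import torch

# Shapes from the new unit test: the unpatched layer expects 10 features, but a
# FluxControl-style input carries 20. Without slicing, the base matmul fails.
base_layer = torch.nn.Linear(10, 40)
x = torch.randn(1, 20)

try:
    base_layer(x)
except RuntimeError as e:
    print("shape mismatch:", e)

# The sidecar forward slices the input down to the layer's in_features first.
print(base_layer(x[..., : base_layer.in_features]).shape)  # torch.Size([1, 40])
```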
@@ -4,17 +4,80 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
     add_nullable_tensors,
 )
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 
 
+def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
+    """An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
+    x = torch.nn.functional.linear(input, lora_layer.down)
+    if lora_layer.mid is not None:
+        x = torch.nn.functional.linear(x, lora_layer.mid)
+    x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
+    x *= lora_weight * lora_layer.scale()
+    return x
+
+
+def concatenated_lora_forward(
+    input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
+) -> torch.Tensor:
+    """An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
+    x_chunks: list[torch.Tensor] = []
+    for lora_layer in concatenated_lora_layer.lora_layers:
+        x_chunk = torch.nn.functional.linear(input, lora_layer.down)
+        if lora_layer.mid is not None:
+            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
+        x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
+        x_chunk *= lora_weight * lora_layer.scale()
+        x_chunks.append(x_chunk)
+
+    # TODO(ryand): Generalize to support concat_axis != 0.
+    assert concatenated_lora_layer.concat_axis == 0
+    x = torch.cat(x_chunks, dim=-1)
+    return x
+
+
+def autocast_linear_forward_sidecar_patches(
+    orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
+) -> torch.Tensor:
+    """Run a linear layer with sidecar patches.
+    Compatible with both quantized and non-quantized Linear layers.
+    """
+    # First, apply the original linear layer.
+    # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
+    # change the linear layer's in_features.
+    orig_input = input
+    input = orig_input[..., : orig_module.in_features]
+    output = orig_module._autocast_forward(input)
+
+    # Then, apply layers for which we have optimized implementations.
+    unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
+    for patch, patch_weight in patches_and_weights:
+        if isinstance(patch, FluxControlLoRALayer):
+            # Note that we use the original input here, not the sliced input.
+            output += linear_lora_forward(orig_input, patch, patch_weight)
+        elif isinstance(patch, LoRALayer):
+            output += linear_lora_forward(input, patch, patch_weight)
+        elif isinstance(patch, ConcatenatedLoRALayer):
+            output += concatenated_lora_forward(input, patch, patch_weight)
+        else:
+            unprocessed_patches_and_weights.append((patch, patch_weight))
+
+    # Finally, apply any remaining patches.
+    if len(unprocessed_patches_and_weights) > 0:
+        aggregated_param_residuals = orig_module._aggregate_patch_parameters(unprocessed_patches_and_weights)
+        output += torch.nn.functional.linear(
+            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
+        )
+
+    return output
+
+
 class CustomLinear(torch.nn.Linear, CustomModuleMixin):
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
-        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
-        weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
-        bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
-        return torch.nn.functional.linear(input, weight, bias)
+        return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)
 
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
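The remaining hunks below are from the accompanying unit test module. As a side note on linear_lora_forward above: chaining the down and up projections is equivalent to a single matmul against the fused residual weight up @ down, which is effectively what gets folded into the layer weight when the patch is applied directly rather than as a sidecar. A small plain-torch check of that equivalence (raw tensors, not the LoRALayer class):

```python
import torch

# Raw-tensor check (not the InvokeAI LoRALayer class): applying the down and up
# projections one after the other matches a single matmul with the fused
# residual weight up @ down.
rank, in_features, out_features = 4, 20, 40
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
x = torch.randn(2, in_features)

chained = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)
fused = torch.nn.functional.linear(x, up @ down)
assert torch.allclose(chained, fused, atol=1e-4)
```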
@@ -10,9 +10,12 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     unwrap_custom_layer,
     wrap_custom_layer,
 )
+from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
     build_linear_8bit_lt_layer,
 )
@@ -272,6 +275,7 @@ LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float
         "linear_single_lora",
         "linear_multiple_loras",
         "linear_concatenated_lora",
+        "linear_flux_control_lora",
     ]
 )
 def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
@@ -338,6 +342,25 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
 
         input = torch.randn(1, in_features)
         return (layer, [(concatenated_lora_layer, 0.7)], input, True)
+    elif layer_type == "linear_flux_control_lora":
+        # Create a linear layer.
+        orig_in_features = 10
+        out_features = 40
+        layer = torch.nn.Linear(orig_in_features, out_features)
+
+        # Create a FluxControlLoRALayer.
+        patched_in_features = 20
+        rank = 4
+        lora_layer = FluxControlLoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, patched_in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+
+        input = torch.randn(1, patched_in_features)
+        return (layer, [(lora_layer, 0.7)], input, True)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
 
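For intuition on the shapes the new fixture chooses: the base layer output (computed from the sliced input) and the FluxControl LoRA residual (computed from the full input) both come out as (1, out_features), so they can be summed inside autocast_linear_forward_sidecar_patches. A rough plain-torch sketch of that shape bookkeeping, not the InvokeAI classes:

```python
import torch

# Rough shape walk-through with the fixture's values; raw tensors, not
# FluxControlLoRALayer or CustomLinear.
orig_in_features, patched_in_features, out_features, rank = 10, 20, 40, 4
layer = torch.nn.Linear(orig_in_features, out_features)
up = torch.randn(out_features, rank)
down = torch.randn(rank, patched_in_features)
bias = torch.randn(out_features)

x = torch.randn(1, patched_in_features)
base = layer(x[..., :orig_in_features])  # (1, 40) from the sliced input
residual = torch.nn.functional.linear(
    torch.nn.functional.linear(x, down), up, bias=bias
)  # (1, 40) from the full input
assert (base + residual).shape == (1, out_features)
```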
@@ -356,18 +379,21 @@ def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchU
     # Patch the LoRA layer into the linear layer.
     layer_patched = copy.deepcopy(layer)
     for patch, weight in patches:
         patch.to(torch.device(device))
-        parameters = patch.get_parameters(layer_patched, weight=weight)
-        for param_name, param_weight in parameters.items():
-            module_param = getattr(layer_patched, param_name)
-            module_param.data += param_weight
+        LayerPatcher._apply_model_layer_patch(
+            module_to_patch=layer_patched,
+            module_to_patch_key="",
+            patch=patch,
+            patch_weight=weight,
+            original_weights=OriginalWeightsStorage(),
+        )
 
     # Wrap the original layer in a custom layer and add the patch to it as a sidecar.
     custom_layer = wrap_single_custom_layer(layer)
     for patch, weight in patches:
         patch.to(torch.device(device))
         custom_layer.add_patch(patch, weight)
 
     # Run inference with the original layer and the patched layer and assert they are equal.
     output_patched = layer_patched(input)
     output_custom = custom_layer(input)
-    assert torch.allclose(output_patched, output_custom)
+    assert torch.allclose(output_patched, output_custom, atol=1e-6)
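On the loosened assertion at the end of the test: folding the patch into the weight versus applying it as a separate sidecar matmul accumulates floating point rounding differently, so bit-exact equality is not guaranteed; hence atol=1e-6. A small plain-torch illustration of the effect (not the test itself):

```python
import torch

# Fusing a residual into the weight vs. adding it as a second matmul changes the
# order of floating point accumulation, so the two paths can differ slightly.
torch.manual_seed(0)
w = torch.randn(40, 20)
delta = torch.randn(40, 20) * 1e-3
x = torch.randn(1, 20)

fused = torch.nn.functional.linear(x, w + delta)
sidecar = torch.nn.functional.linear(x, w) + torch.nn.functional.linear(x, delta)
print((fused - sidecar).abs().max())  # tiny, but usually not exactly zero
```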