diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
index 58027b2951..e833591109 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
@@ -4,17 +4,80 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
-from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
-    add_nullable_tensors,
-)
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+
+
+def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
+    """An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
+    x = torch.nn.functional.linear(input, lora_layer.down)
+    if lora_layer.mid is not None:
+        x = torch.nn.functional.linear(x, lora_layer.mid)
+    x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
+    x *= lora_weight * lora_layer.scale()
+    return x
+
+
+def concatenated_lora_forward(
+    input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
+) -> torch.Tensor:
+    """An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
+    x_chunks: list[torch.Tensor] = []
+    for lora_layer in concatenated_lora_layer.lora_layers:
+        x_chunk = torch.nn.functional.linear(input, lora_layer.down)
+        if lora_layer.mid is not None:
+            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
+        x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
+        x_chunk *= lora_weight * lora_layer.scale()
+        x_chunks.append(x_chunk)
+
+    # TODO(ryand): Generalize to support concat_axis != 0.
+    assert concatenated_lora_layer.concat_axis == 0
+    x = torch.cat(x_chunks, dim=-1)
+    return x
+
+
+def autocast_linear_forward_sidecar_patches(
+    orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
+) -> torch.Tensor:
+    """A function that runs a linear layer with sidecar patches applied.
+    Compatible with both quantized and non-quantized Linear layers.
+    """
+    # First, apply the original linear layer.
+    # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
+    # change the linear layer's in_features.
+    orig_input = input
+    input = orig_input[..., : orig_module.in_features]
+    output = orig_module._autocast_forward(input)
+
+    # Then, apply layers for which we have optimized implementations.
+    unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
+    for patch, patch_weight in patches_and_weights:
+        if isinstance(patch, FluxControlLoRALayer):
+            # Note that we use the original input here, not the sliced input.
+            output += linear_lora_forward(orig_input, patch, patch_weight)
+        elif isinstance(patch, LoRALayer):
+            output += linear_lora_forward(input, patch, patch_weight)
+        elif isinstance(patch, ConcatenatedLoRALayer):
+            output += concatenated_lora_forward(input, patch, patch_weight)
+        else:
+            unprocessed_patches_and_weights.append((patch, patch_weight))
+
+    # Finally, apply any remaining patches.
+    if len(unprocessed_patches_and_weights) > 0:
+        aggregated_param_residuals = orig_module._aggregate_patch_parameters(unprocessed_patches_and_weights)
+        output += torch.nn.functional.linear(
+            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
+        )
+
+    return output
 
 
 class CustomLinear(torch.nn.Linear, CustomModuleMixin):
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
-        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
-        weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
-        bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
-        return torch.nn.functional.linear(input, weight, bias)
+        return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)
 
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
index 2668ca61a4..b01a744be6 100644
--- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
@@ -10,9 +10,12 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     unwrap_custom_layer,
     wrap_custom_layer,
 )
+from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
     build_linear_8bit_lt_layer,
 )
@@ -272,6 +275,7 @@ LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float
         "linear_single_lora",
         "linear_multiple_loras",
         "linear_concatenated_lora",
+        "linear_flux_control_lora",
     ]
 )
 def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
@@ -338,6 +342,25 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
 
         input = torch.randn(1, in_features)
         return (layer, [(concatenated_lora_layer, 0.7)], input, True)
+    elif layer_type == "linear_flux_control_lora":
+        # Create a linear layer.
+        orig_in_features = 10
+        out_features = 40
+        layer = torch.nn.Linear(orig_in_features, out_features)
+
+        # Create a FluxControlLoRALayer.
+        patched_in_features = 20
+        rank = 4
+        lora_layer = FluxControlLoRALayer(
+            up=torch.randn(out_features, rank),
+            mid=None,
+            down=torch.randn(rank, patched_in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+
+        input = torch.randn(1, patched_in_features)
+        return (layer, [(lora_layer, 0.7)], input, True)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
 
@@ -356,18 +379,21 @@ def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchU
     # Patch the LoRA layer into the linear layer.
     layer_patched = copy.deepcopy(layer)
     for patch, weight in patches:
-        patch.to(torch.device(device))
-        parameters = patch.get_parameters(layer_patched, weight=weight)
-        for param_name, param_weight in parameters.items():
-            module_param = getattr(layer_patched, param_name)
-            module_param.data += param_weight
+        LayerPatcher._apply_model_layer_patch(
+            module_to_patch=layer_patched,
+            module_to_patch_key="",
+            patch=patch,
+            patch_weight=weight,
+            original_weights=OriginalWeightsStorage(),
+        )
 
     # Wrap the original layer in a custom layer and add the patch to it as a sidecar.
     custom_layer = wrap_single_custom_layer(layer)
     for patch, weight in patches:
+        patch.to(torch.device(device))
         custom_layer.add_patch(patch, weight)
 
     # Run inference with the original layer and the patched layer and assert they are equal.
     output_patched = layer_patched(input)
     output_custom = custom_layer(input)
-    assert torch.allclose(output_patched, output_custom)
+    assert torch.allclose(output_patched, output_custom, atol=1e-6)
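Reviewer note: the snippet below is a minimal, self-contained sketch (plain PyTorch only; the tensor names, dimensions, and the `scale = alpha / rank` scaling are illustrative assumptions, not InvokeAI API) of the equivalence that `test_sidecar_patches` asserts for the plain-LoRA case: computing the residual from the low-rank projections at runtime, as `linear_lora_forward` does, gives the same result as folding `up @ down` into the weight, as the baked-in patching path does.

```python
import copy

import torch

torch.manual_seed(0)

in_features, out_features, rank = 8, 16, 4
patch_weight = 0.7    # same patch weight as the fixtures above
alpha = 1.0
scale = alpha / rank  # assumed LoRA scaling for this sketch

base = torch.nn.Linear(in_features, out_features)
up = torch.randn(out_features, rank)
down = torch.randn(rank, in_features)
x = torch.randn(1, in_features)

# Sidecar-style: compute the low-rank residual on the activations (two small matmuls).
residual = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)
output_sidecar = base(x) + patch_weight * scale * residual

# Baked-in style: fold the low-rank product into the weight, then run the layer once.
patched = copy.deepcopy(base)
patched.weight.data += patch_weight * scale * (up @ down)
output_baked = patched(x)

assert torch.allclose(output_sidecar, output_baked, atol=1e-6)
```

The FluxControlLoRA case relies on the same relationship, except that the sidecar residual is computed from the full (wider) input while the base layer only sees the first `in_features` columns, which is why `autocast_linear_forward_sidecar_patches` keeps `orig_input` around.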