Update BaseLayerPatch.get_parameters(...) to accept a dict of orig_parameters rather than orig_module. This will enable compatibility between patching and cpu->gpu streaming.

This commit is contained in:
Ryan Dick 2024-12-28 21:12:53 +00:00
parent 20acfc9a00
commit 2855bb6b41
11 changed files with 25 additions and 21 deletions
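In practice, callers now hand each patch a plain name-to-tensor snapshot of the module's own parameters instead of the module object, so the patch math no longer depends on where the module itself currently lives (which is what makes it compatible with cpu->gpu streaming). A minimal sketch of the snapshot that callers build (plain PyTorch; nothing InvokeAI-specific is assumed):

    import torch

    module = torch.nn.Linear(in_features=4, out_features=8)

    # Snapshot of this module's own parameters (recurse=False skips submodules),
    # keyed by name: {"weight": ..., "bias": ...}.
    orig_parameters = dict(module.named_parameters(recurse=False))

    # Old convention: patch.get_parameters(module, weight=patch_weight)
    # New convention: patch.get_parameters(orig_parameters, weight=patch_weight)
    print(sorted(orig_parameters.keys()))  # ["bias", "weight"]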

View File

@@ -32,9 +32,9 @@ class CustomModuleMixin:
         params: dict[str, torch.Tensor] = {}
 
         for patch, patch_weight in patches_and_weights:
-            # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
-            # module, this might fail or return incorrect results.
-            layer_params = patch.get_parameters(self, weight=patch_weight)
+            # TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original
+            # parameters, this might fail or return incorrect results.
+            layer_params = patch.get_parameters(dict(self.named_parameters(recurse=False)), weight=patch_weight)  # type: ignore
 
             for param_name, param_weight in layer_params.items():
                 if param_name not in params:

View File

@@ -166,7 +166,9 @@ class LayerPatcher:
         # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
         # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-        for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items():
+        for param_name, param_weight in patch.get_parameters(
+            dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
+        ).items():
             param_key = module_to_patch_key + "." + param_name
             module_param = module_to_patch.get_parameter(param_name)
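For context, the residual dicts returned by get_parameters() are applied additively to the original parameters. A simplified illustration of that application step, using a hypothetical params dict standing in for a patch's output (the real LayerPatcher additionally handles dtype/device casting and the bookkeeping needed to unpatch later):

    import torch

    module_to_patch = torch.nn.Linear(4, 8)
    # Hypothetical residuals, standing in for the output of patch.get_parameters(...).
    params = {"weight": torch.zeros_like(module_to_patch.weight)}

    with torch.no_grad():
        for param_name, param_weight in params.items():
            module_param = module_to_patch.get_parameter(param_name)
            # Additive residual update; parameters omitted from the dict (e.g. "bias") are untouched.
            module_param += param_weight.to(dtype=module_param.dtype)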

View File

@@ -5,7 +5,7 @@ import torch
 class BaseLayerPatch(ABC):
     @abstractmethod
-    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
         """Get the parameter residual updates that should be applied to the original parameters. Parameters omitted
         from the returned dict are not updated.
         """

View File

@@ -30,7 +30,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
         layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers]  # pyright: ignore[reportArgumentType]
         return torch.cat(layer_weights, dim=self.concat_axis)
 
-    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+    def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
         # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
         # require this value, we will need to implement chunking of the original bias tensor here.
         # Note that we must apply the sub-layer scales here.

View File

@@ -8,11 +8,11 @@ class FluxControlLoRALayer(LoRALayer):
     shapes don't match.
     """
 
-    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
         """This overrides the base class behavior to skip the reshaping step."""
         scale = self.scale()
-        params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
-        bias = self.get_bias(orig_module.bias)
+        params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
+        bias = self.get_bias(orig_parameters.get("bias", None))
         if bias is not None:
             params["bias"] = bias * (weight * scale)

View File

@@ -54,19 +54,19 @@ class LoRALayerBase(BaseLayerPatch):
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
 
-    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+    def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
         return self.bias
 
-    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
         scale = self.scale()
-        params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
-        bias = self.get_bias(orig_module.bias)
+        params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
+        bias = self.get_bias(orig_parameters.get("bias", None))
         if bias is not None:
             params["bias"] = bias * (weight * scale)
 
         # Reshape all params to match the original module's shape.
         for param_name, param_weight in params.items():
-            orig_param = orig_module.get_parameter(param_name)
+            orig_param = orig_parameters[param_name]
             if param_weight.shape != orig_param.shape:
                 params[param_name] = param_weight.reshape(orig_param.shape)

View File

@@ -14,10 +14,10 @@ class SetParameterLayer(BaseLayerPatch):
         self.weight = weight
         self.param_name = param_name
 
-    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
         # Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
         # Control LoRA implementation.
-        diff = self.weight - orig_module.get_parameter(self.param_name)
+        diff = self.weight - orig_parameters[self.param_name]
         return {self.param_name: diff}
 
     def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):

View File

@@ -39,8 +39,10 @@ class BaseSidecarWrapper(torch.nn.Module):
         for patch, patch_weight in patches_and_weights:
             # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
-            # module, this might fail or return incorrect results.
-            layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
+            # parameters, this might fail or return incorrect results.
+            layer_params = patch.get_parameters(
+                dict(self._orig_module.named_parameters(recurse=False)), weight=patch_weight
+            )
 
             for param_name, param_weight in layer_params.items():
                 if param_name not in params:

View File

@@ -18,7 +18,7 @@ def test_flux_control_lora_layer_get_parameters():
     orig_module = torch.nn.Linear(small_in_features, out_features)
 
     # Test that get_parameters() behaves as expected in spite of the difference in in_features shapes.
-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert "weight" in params
     assert params["weight"].shape == (out_features, big_in_features)
     assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha)

View File

@@ -107,7 +107,7 @@ def test_lora_layer_get_parameters():
     # Create mock original module
    orig_module = torch.nn.Linear(in_features, out_features)
 
-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert "weight" in params
     assert params["weight"].shape == orig_module.weight.shape
     assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha)

View File

@@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters():
     target_weight = torch.randn(8, 4)
     layer = SetParameterLayer(param_name="weight", weight=target_weight)
 
-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert len(params) == 1
     new_weight = orig_module.weight + params["weight"]
    assert torch.allclose(new_weight, target_weight)