import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.util.calc_tensor_size import calc_tensor_size


class SetParameterLayer(BaseLayerPatch):
    """A layer that sets a single parameter to a new target value.
    (The diff between the target value and current value is calculated internally.)
    """

    def __init__(self, param_name: str, weight: torch.Tensor):
        super().__init__()
        self.weight = weight
        self.param_name = param_name

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        # Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
        # Control LoRA implementation.
        diff = self.weight - orig_parameters[self.param_name]
        return {self.param_name: diff}

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        # Move/cast the stored target weight in place.
        self.weight = self.weight.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        # The memory footprint of this patch is just the stored target weight.
        return calc_tensor_size(self.weight)
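

# Illustrative usage sketch (not part of the original module): it shows that adding the diff
# returned by get_parameters() to the original parameter reproduces the target value, and that
# the `weight` argument is ignored. The parameter name "foo.weight" and the tensor shapes below
# are arbitrary, hypothetical examples.
if __name__ == "__main__":
    target = torch.full((4, 4), 2.0)
    layer = SetParameterLayer("foo.weight", target)
    layer.to(dtype=torch.float32)

    orig = {"foo.weight": torch.zeros(4, 4)}
    # Any value of `weight` yields the same diff, since SetParameterLayer ignores it.
    params = layer.get_parameters(orig, weight=0.5)

    patched = orig["foo.weight"] + params["foo.weight"]
    assert torch.allclose(patched, target)
    print(f"patch size in bytes: {layer.calc_size()}")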