Support bnb quantized nf4 flux models, use controlnet vae, only support 1 structural lora per transformer. Various other refactors and bugfixes

Brandon Rising 2024-12-10 03:26:29 -05:00 committed by Kent Keirsey
parent f3b253987f
commit 5a035dd19f
10 changed files with 121 additions and 45 deletions

View File

@@ -290,20 +290,22 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
device=x.device,
)
img_cond = None
for struct_lora in self.transformer.structural_loras:
if struct_lora := self.transformer.structural_lora:
# What should we do when we have multiple of these?
ae_info = context.models.load(struct_lora.vae.vae)
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a strutural lora")
ae_info = context.models.load(self.controlnet_vae.vae)
img = context.images.get_pil(struct_lora.img.image_name)
with ae_info as ae:
assert isinstance(ae, AutoEncoder)
img_cond = prepare_control(x, ae, img)
img_cond = prepare_control(self.height, self.width, self.seed, ae, img)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
assert isinstance(transformer, Flux)
config = transformer_info.config
assert config is not None
if self.transformer.structural_loras:
if self.transformer.structural_lora:
replace_linear_with_lora(transformer, 128)
# Apply LoRA models to the transformer.
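For readers without the +/- markers, a minimal sketch of how this hunk reads after the change (names as in the diff; the surrounding invocation plumbing is elided): at most one structural LoRA is taken from the transformer field, the dedicated controlnet_vae is now required to encode its conditioning image, and prepare_control receives the target size and seed instead of the latent tensor.

img_cond = None
if struct_lora := self.transformer.structural_lora:
    if not self.controlnet_vae:
        raise ValueError("controlnet_vae must be set when using a structural lora")
    ae_info = context.models.load(self.controlnet_vae.vae)
    img = context.images.get_pil(struct_lora.img.image_name)
    with ae_info as ae:
        assert isinstance(ae, AutoEncoder)
        img_cond = prepare_control(self.height, self.width, self.seed, ae, img)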
@@ -698,7 +700,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras, *self.transformer.structural_loras]
loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
if self.transformer.structural_lora:
loras.append(self.transformer.structural_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)

View File

@@ -41,7 +41,6 @@ class FluxStructuralLoRALoaderInvocation(BaseInvocation):
input=Input.Connection,
title="FLUX Transformer",
)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: ImageField = InputField(
description="The image to encode.",
)
@@ -54,7 +53,7 @@ class FluxStructuralLoRALoaderInvocation(BaseInvocation):
raise ValueError(f"Unknown lora: {lora_key}!")
# Check for existing LoRAs with the same key.
if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.structural_loras):
if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
output = FluxStructuralLoRALoaderOutput()
@@ -62,13 +61,10 @@ class FluxStructuralLoRALoaderInvocation(BaseInvocation):
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.structural_loras.append(
StructuralLoRAField(
lora=self.lora,
vae=self.vae,
img=self.image,
weight=self.weight,
)
output.transformer.structural_lora = StructuralLoRAField(
lora=self.lora,
img=self.image,
weight=self.weight,
)
return output

View File

@@ -76,13 +76,11 @@ class VAEField(BaseModel):
class StructuralLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")
vae: VAEField = Field(description="VAE To use with structural lora")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
structural_loras: List[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading")
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRA to apply on model loading", default=None)
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):

View File

@@ -71,9 +71,9 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
pred = model(
img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img,
img=pred_img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
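The refactor only names the concatenated tensor, but the shape logic is worth spelling out: the control latents are stacked along the channel axis, which is exactly why the patched img_in linear layer doubles its in_features (see replace_linear_with_lora below). A minimal shape sketch with illustrative values:

import torch

img = torch.randn(1, 4096, 64)       # packed noise latents: (batch, seq, channels)
img_cond = torch.randn(1, 4096, 64)  # packed control latents, same layout

pred_img = torch.cat((img, img_cond), dim=-1)
print(pred_img.shape)                # torch.Size([1, 4096, 128])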

View File

@@ -6,20 +6,22 @@ from einops import rearrange
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def prepare_control(
img: torch.Tensor,
height: int,
width: int,
seed: int,
ae: AutoEncoder,
cond_image: Image.Image,
) -> torch.Tensor:
# load and encode the conditioning image
_, h, w = img.shape
img_cond = cond_image.convert("RGB")
width = w * 8
height = h * 8
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = torch.from_numpy(img_cond).float()
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
img_cond = img_cond.to(dtype=img.dtype, device=img.device)
img_cond = ae.encode(img_cond)
ae_dtype = next(iter(ae.parameters())).dtype
ae_device = next(iter(ae.parameters())).device
img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
generator = torch.Generator(device=ae_device).manual_seed(seed)
img_cond = ae.encode(img_cond, sample=True, generator=generator)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return img_cond
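A hypothetical call site for the new signature (values illustrative, autoencoder assumed to be already loaded): the conditioning image is resized from the requested output dimensions rather than derived from a latent tensor, and the encode draws a seeded sample from the VAE posterior so the conditioning is reproducible per seed.

from PIL import Image

control_image = Image.open("pose.png")  # hypothetical conditioning image
img_cond = prepare_control(
    height=1024,                        # requested output height in pixels
    width=1024,                         # requested output width in pixels
    seed=42,                            # seeds the generator used for sampling
    ae=autoencoder,                     # an AutoEncoder instance (assumed loaded)
    cond_image=control_image,
)
# For 1024x1024 this yields packed latents of shape (1, 4096, 64).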

View File

@@ -1,19 +1,31 @@
import torch
import bitsandbytes as bnb
from torch import nn
def replace_linear_with_lora(
module: nn.Module,
max_rank: int,
scale: float = 1.0,
) -> None:
for name, child in module.named_children():
if isinstance(child, (LinearLora, BNBNF4LinearLora)):
# TODO: We really need to undo this after generation runs
return
if isinstance(child, nn.Linear):
new_lora = LinearLora(
in_features=child.in_features,
dtype = child.weight.dtype
loraModule = LinearLora
if hasattr(child, "compute_dtype"):
dtype = getattr(child, "compute_dtype")
loraModule = BNBNF4LinearLora
new_lora = loraModule(
# Double the in_features to accommodate the control image. The conditional avoids also doubling the final_layer's in_features.
in_features=child.out_features*2 if child.in_features == child.out_features else child.in_features,
out_features=child.out_features,
bias=child.bias is not None,
rank=max_rank,
scale=scale,
dtype=getattr(child, "compute_dtype") if hasattr(child, "compute_dtype") else child.weight.dtype,
dtype=dtype,
device=child.weight.device,
)
new_lora.weight = child.weight
@@ -26,6 +38,61 @@ def replace_linear_with_lora(
scale=scale,
)
class BNBNF4LinearLora(bnb.nn.LinearNF4):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
rank: int,
dtype: torch.dtype,
device: torch.device,
lora_bias: bool = True,
scale: float = 1.0,
*args,
**kwargs,
) -> None:
super().__init__(
input_features=in_features,
output_features=out_features,
bias=bias,
device=device,
compute_dtype=dtype,
*args,
**kwargs,
)
assert isinstance(scale, float), "scale must be a float"
self.scale = scale
self.rank = rank
self.lora_bias = lora_bias
self.dtype = dtype
self.device = device
if rank > (new_rank := min(self.out_features, self.in_features)):
self.rank = new_rank
self.lora_A = bnb.nn.LinearNF4(
input_features=in_features,
output_features=self.rank,
bias=False,
device=device,
compute_dtype=dtype
)
self.lora_B = bnb.nn.LinearNF4(
input_features=self.rank,
output_features=out_features,
bias=self.lora_bias,
device=device,
compute_dtype=dtype
)
def set_scale(self, scale: float) -> None:
assert isinstance(scale, float), "scalar value must be a float"
self.scale = scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
base_out = super().forward(x)
_lora_out_B = self.lora_B(self.lora_A(x))
lora_update = _lora_out_B * self.scale
return base_out + lora_update
class LinearLora(nn.Linear):
def __init__(
self,
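The choice between the two wrappers is duck-typed: bitsandbytes' Linear4bit (which LinearNF4 subclasses) sets a compute_dtype attribute in its constructor, while a plain nn.Linear has none. A minimal sketch of the detection, using the classes defined above and assuming only that attribute contract:

import torch
from torch import nn

def select_lora_wrapper(child: nn.Linear) -> tuple[type, torch.dtype]:
    # bnb's Linear4bit sets compute_dtype; a plain nn.Linear has no such attribute.
    if hasattr(child, "compute_dtype"):
        return BNBNF4LinearLora, child.compute_dtype
    return LinearLora, child.weight.dtype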

View File

@@ -26,13 +26,16 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
state_dict = _convert_lora_bfl_control(state_dict=state_dict)
converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
for key, value in converted_state_dict.items():
key_props = key.split(".")
layer_name = ".".join(key_props[:-1])
param_name = key_props[-1]
# We got this loading using lora_down/lora_up keys, but that didn't seem to match this LoRA's structure.
# Leaving it in since it doesn't hurt anything and may prove useful.
layer_prop_size = -2 if any(prop in key for prop in ["lora_down", "lora_up"]) else -1
layer_name = ".".join(key_props[:layer_prop_size])
param_name = ".".join(key_props[layer_prop_size:])
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
@@ -50,10 +53,11 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
def _convert_lora_bfl_control(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
sd_out: dict[str, torch.Tensor] = {}
for k in state_dict:
if k.endswith(".scale"): # TODO: Fix these patches
continue
k_to = k.replace(".lora_B.bias", ".lora_B.diff_b")\
.replace(".lora_A.weight", ".lora_A.diff")\
.replace(".lora_B.weight", ".lora_B.diff")\
.replace("_norm.scale", "_norm.scale.set_weight")
.replace(".lora_B.weight", ".lora_B.diff")
sd_out[k_to] = state_dict[k]
# sd_out["img_in.reshape_weight"] = torch.tensor([state_dict["img_in.lora_B.weight"].shape[0], state_dict["img_in.lora_A.weight"].shape[1]])
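For reference, a hedged before/after of the key conversion (example keys, not dumped from a real checkpoint):

# "img_in.lora_A.weight"  -> "img_in.lora_A.diff"
# "img_in.lora_B.weight"  -> "img_in.lora_B.diff"
# "img_in.lora_B.bias"    -> "img_in.lora_B.diff_b"
# keys ending in ".scale" -> skipped for now (see the TODO above)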

View File

@@ -7,17 +7,22 @@ from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class SetWeightLayer(LoRALayerBase):
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
# TODO: Rework just about everything in this class
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight)
self.weight = torch.nn.Parameter(weight) if weight is not None else None
self.manual_scale = scale
def scale(self):
return self.manual_scale.float() if self.manual_scale is not None else super().scale()
@classmethod
def from_state_dict_values(
cls,
values: Dict[str, torch.Tensor],
):
layer = cls(weight=values["set_weight"], bias=values.get("set_bias", None))
cls.warn_on_unhandled_keys(values=values, handled_keys={"set_weight", "set_bias"})
layer = cls(weight=values.get("set_weight", None), bias=values.get("set_bias", None), scale=values.get("set_scale", None))
cls.warn_on_unhandled_keys(values=values, handled_keys={"set_weight", "set_bias", "set_scale"})
return layer
def rank(self) -> int | None:
@@ -28,7 +33,10 @@ class SetWeightLayer(LoRALayerBase):
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
if self.weight is not None:
self.weight = self.weight.to(device=device, dtype=dtype)
if self.manual_scale is not None:
self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.weight)
return super().calc_size() + calc_tensor_size(self.manual_scale)
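A hypothetical state-dict fragment that the updated from_state_dict_values accepts; since all three keys are now optional, any subset works:

import torch

values = {
    "set_weight": torch.randn(3072, 64),  # illustrative shape
    "set_scale": torch.tensor(1.0),       # overrides the default scale()
}
layer = SetWeightLayer.from_state_dict_values(values)  # set_bias omitted -> bias is None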

View File

@@ -30,7 +30,7 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLo
return IA3Layer.from_state_dict_values(state_dict)
elif "w_norm" in state_dict:
return NormLayer.from_state_dict_values(state_dict)
elif "set_weight" in state_dict:
elif any(key in state_dict for key in ["set_weight", "set_bias", "set_scale"]):
return SetWeightLayer.from_state_dict_values(state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

View File

@@ -6,7 +6,6 @@ import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_weight_layer import SetWeightLayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
@@ -116,8 +115,6 @@ class LoRAPatcher:
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
if isinstance(layer, SetWeightLayer):
module_param = lora_param_weight
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
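With the SetWeightLayer special case removed, every layer now takes the same additive path; schematically (a sketch of the arithmetic, not the exact patcher code):

# for each parameter a layer produces:
#   module_param += layer_param.reshape(module_param.shape) * patch_weight * layer.scale()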