Support bnb-quantized NF4 FLUX models, use the ControlNet VAE, and only support one structural LoRA per transformer; various other refactors and bugfixes
This commit is contained in:
parent f3b253987f
commit 5a035dd19f
@@ -290,20 +290,22 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             device=x.device,
         )
         img_cond = None
-        for struct_lora in self.transformer.structural_loras:
-            # What should we do when we have multiple of these?
-            ae_info = context.models.load(struct_lora.vae.vae)
+        if struct_lora := self.transformer.structural_lora:
+            if not self.controlnet_vae:
+                raise ValueError("controlnet_vae must be set when using a structural lora")
+            ae_info = context.models.load(self.controlnet_vae.vae)
             img = context.images.get_pil(struct_lora.img.image_name)
             with ae_info as ae:
                 assert isinstance(ae, AutoEncoder)
-                img_cond = prepare_control(x, ae, img)
+                img_cond = prepare_control(self.height, self.width, self.seed, ae, img)

         # Load the transformer model.
         (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
         assert isinstance(transformer, Flux)
         config = transformer_info.config
         assert config is not None
-        if self.transformer.structural_loras:
+        if self.transformer.structural_lora:
             replace_linear_with_lora(transformer, 128)

         # Apply LoRA models to the transformer.
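A minimal, self-contained sketch of the new gating above: the whole conditioning block now hinges on a single optional field via a walrus assignment, with the ControlNet VAE required whenever a structural LoRA is present. `TransformerFieldStub` and the string-returning encode are hypothetical stand-ins for InvokeAI's context and VAE objects, not the real classes:

from dataclasses import dataclass
from typing import Optional

@dataclass
class StructLora:
    image_name: str

@dataclass
class TransformerFieldStub:
    structural_lora: Optional[StructLora] = None

def build_img_cond(field: TransformerFieldStub, controlnet_vae: object) -> Optional[str]:
    img_cond = None
    # Walrus assignment: the body runs only when a structural LoRA is attached.
    if struct_lora := field.structural_lora:
        if not controlnet_vae:
            raise ValueError("controlnet_vae must be set when using a structural lora")
        img_cond = f"encoded({struct_lora.image_name})"  # stands in for the VAE encode
    return img_cond

print(build_img_cond(TransformerFieldStub(), None))  # None
print(build_img_cond(TransformerFieldStub(StructLora("canny.png")), object()))  # encoded(canny.png)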
@@ -698,7 +700,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         return pos_ip_adapter_extensions, neg_ip_adapter_extensions

     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
-        loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras, *self.transformer.structural_loras]
+        loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
+        if self.transformer.structural_lora:
+            loras.append(self.transformer.structural_lora)
         for lora in loras:
             lora_info = context.models.load(lora.lora)
             assert isinstance(lora_info.model, LoRAModelRaw)
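A runnable sketch of the iterator pattern above — the single optional structural LoRA simply rides along with the regular ones. `FakeLoRA` is illustrative, not an InvokeAI class:

from typing import Iterator, Optional, Tuple

class FakeLoRA:
    def __init__(self, name: str, weight: float):
        self.name = name
        self.weight = weight

def lora_iterator(regular: list[FakeLoRA], structural: Optional[FakeLoRA]) -> Iterator[Tuple[str, float]]:
    loras = [*regular]
    if structural:
        loras.append(structural)
    for lora in loras:
        yield (lora.name, lora.weight)  # stands in for (loaded model, weight)

for name, weight in lora_iterator([FakeLoRA("style", 0.8)], FakeLoRA("control", 1.0)):
    print(name, weight)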
@@ -41,7 +41,6 @@ class FluxStructuralLoRALoaderInvocation(BaseInvocation):
         input=Input.Connection,
         title="FLUX Transformer",
     )
-    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
     image: ImageField = InputField(
         description="The image to encode.",
     )
@@ -54,7 +53,7 @@ class FluxStructuralLoRALoaderInvocation(BaseInvocation):
             raise ValueError(f"Unknown lora: {lora_key}!")

         # Check for existing LoRAs with the same key.
-        if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.structural_loras):
+        if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
             raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')

         output = FluxStructuralLoRALoaderOutput()
@@ -62,13 +61,10 @@ class FluxStructuralLoRALoaderInvocation(BaseInvocation):
         # Attach LoRA layers to the models.
         if self.transformer is not None:
             output.transformer = self.transformer.model_copy(deep=True)
-            output.transformer.structural_loras.append(
-                StructuralLoRAField(
-                    lora=self.lora,
-                    vae=self.vae,
-                    img=self.image,
-                    weight=self.weight,
-                )
-            )
+            output.transformer.structural_lora = StructuralLoRAField(
+                lora=self.lora,
+                img=self.image,
+                weight=self.weight,
+            )

         return output
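The loader now overwrites one optional slot instead of appending to a list. A sketch of the copy-then-assign semantics in plain pydantic v2 — class and field names here are illustrative, not the InvokeAI ones:

from typing import Optional
from pydantic import BaseModel

class StructLoRA(BaseModel):
    key: str
    weight: float = 1.0

class Transformer(BaseModel):
    structural_lora: Optional[StructLoRA] = None

def attach(transformer: Transformer, lora: StructLoRA) -> Transformer:
    if transformer.structural_lora and transformer.structural_lora.key == lora.key:
        raise ValueError(f'Structural LoRA "{lora.key}" already applied to transformer.')
    out = transformer.model_copy(deep=True)  # never mutate the incoming graph value
    out.structural_lora = lora
    return out

print(attach(Transformer(), StructLoRA(key="depth")).structural_lora)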
@@ -76,13 +76,11 @@ class VAEField(BaseModel):

 class StructuralLoRAField(LoRAField):
     img: ImageField = Field(description="Image to use in structural conditioning")
-    vae: VAEField = Field(description="VAE To use with structural lora")


 class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
     loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
-    structural_loras: List[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading")
+    structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRA to apply on model loading", default=None)


 @invocation_output("unet_output")
 class UNetOutput(BaseInvocationOutput):
@@ -71,9 +71,9 @@ def denoise(
         # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
         # tensors. Calculating the sum materializes each tensor into its own instance.
         merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)

+        pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
         pred = model(
-            img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img,
+            img=pred_img,
             img_ids=img_ids,
             txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
             txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
@@ -6,20 +6,22 @@ from einops import rearrange
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder


 def prepare_control(
-    img: torch.Tensor,
+    height: int,
+    width: int,
+    seed: int,
     ae: AutoEncoder,
     cond_image: Image.Image,
 ) -> torch.Tensor:
     # load and encode the conditioning image
-    _, h, w = img.shape
     img_cond = cond_image.convert("RGB")
-    width = w * 8
-    height = h * 8
     img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
     img_cond = np.array(img_cond)
-    img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
+    img_cond = torch.from_numpy(img_cond).float()
     img_cond = rearrange(img_cond, "h w c -> 1 c h w")
-    img_cond = img_cond.to(dtype=img.dtype, device=img.device)
-    img_cond = ae.encode(img_cond)
+    ae_dtype = next(iter(ae.parameters())).dtype
+    ae_device = next(iter(ae.parameters())).device
+    img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
+    generator = torch.Generator(device=ae_device).manual_seed(seed)
+    img_cond = ae.encode(img_cond, sample=True, generator=generator)
     img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
     return img_cond
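The closing rearrange packs the VAE latent into FLUX's 2x2 patch-sequence layout. A standalone shape check; the 16 latent channels and the /8 downsample factor follow the usual FLUX VAE convention, and the concrete sizes are illustrative:

import torch
from einops import rearrange

height, width = 512, 512  # conditioning image size in pixels
latent = torch.randn(1, 16, height // 8, width // 8)  # shape of what ae.encode would return

packed = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
# Each token covers a 2x2 latent patch: (512/16) * (512/16) = 1024 tokens of 16*2*2 = 64 channels.
print(packed.shape)  # torch.Size([1, 1024, 64])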
@@ -1,19 +1,31 @@
 import torch

+import bitsandbytes as bnb
 from torch import nn


 def replace_linear_with_lora(
     module: nn.Module,
     max_rank: int,
     scale: float = 1.0,
 ) -> None:
     for name, child in module.named_children():
+        if isinstance(child, (LinearLora, BNBNF4LinearLora)):
+            # TODO: We really need to undo this after generation runs
+            return
         if isinstance(child, nn.Linear):
-            new_lora = LinearLora(
-                in_features=child.in_features,
+            dtype = child.weight.dtype
+            loraModule = LinearLora
+            if hasattr(child, "compute_dtype"):
+                dtype = getattr(child, "compute_dtype")
+                loraModule = BNBNF4LinearLora
+            new_lora = loraModule(
+                # Double the in_features to accommodate the control image. The conditional avoids widening final_layer's in_features.
+                in_features=child.out_features * 2 if child.in_features == child.out_features else child.in_features,
                 out_features=child.out_features,
                 bias=child.bias is not None,
                 rank=max_rank,
                 scale=scale,
-                dtype=getattr(child, "compute_dtype") if hasattr(child, "compute_dtype") else child.weight.dtype,
+                dtype=dtype,
                 device=child.weight.device,
             )
             new_lora.weight = child.weight
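The doubling rule above appears to rely on layers with matching input and output widths being exactly the ones that see the concatenated (img, img_cond) features, while non-square layers such as final_layer keep their width. A small demonstration of that selection rule, with illustrative sizes:

from torch import nn

def widened_in_features(child: nn.Linear) -> int:
    # Square layers are assumed to receive the doubled (image + control) features.
    return child.out_features * 2 if child.in_features == child.out_features else child.in_features

print(widened_in_features(nn.Linear(3072, 3072)))  # 6144: widened for the control channels
print(widened_in_features(nn.Linear(3072, 64)))    # 3072: non-square, left unchanged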
@@ -26,6 +38,61 @@ def replace_linear_with_lora(
             scale=scale,
         )

+
+class BNBNF4LinearLora(bnb.nn.LinearNF4):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        rank: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        lora_bias: bool = True,
+        scale: float = 1.0,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            input_features=in_features,
+            output_features=out_features,
+            bias=bias,
+            device=device,
+            compute_dtype=dtype,
+            *args,
+            **kwargs,
+        )
+        assert isinstance(scale, float), "scale must be a float"
+        self.scale = scale
+        self.rank = rank
+        self.lora_bias = lora_bias
+        self.dtype = dtype
+        self.device = device
+        if rank > (new_rank := min(self.out_features, self.in_features)):
+            self.rank = new_rank
+        self.lora_A = bnb.nn.LinearNF4(
+            input_features=in_features,
+            output_features=self.rank,
+            bias=False,
+            device=device,
+            compute_dtype=dtype,
+        )
+        self.lora_B = bnb.nn.LinearNF4(
+            input_features=self.rank,
+            output_features=out_features,
+            bias=self.lora_bias,
+            device=device,
+            compute_dtype=dtype,
+        )
+
+    def set_scale(self, scale: float) -> None:
+        assert isinstance(scale, float), "scale must be a float"
+        self.scale = scale
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        base_out = super().forward(x)
+        _lora_out_B = self.lora_B(self.lora_A(x))
+        lora_update = _lora_out_B * self.scale
+        return base_out + lora_update
+

 class LinearLora(nn.Linear):
     def __init__(
         self,
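BNBNF4LinearLora computes y = base(x) + scale * B(A(x)) on top of an NF4-quantized base layer. The same arithmetic with unquantized nn.Linear modules, as a minimal self-contained reference (shapes are illustrative):

import torch
from torch import nn

class TinyLinearLora(nn.Linear):
    def __init__(self, in_features: int, out_features: int, rank: int, scale: float = 1.0):
        super().__init__(in_features, out_features)
        self.scale = scale
        # Low-rank factors: A projects down to `rank`, B projects back up.
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x) + self.lora_B(self.lora_A(x)) * self.scale

layer = TinyLinearLora(in_features=64, out_features=32, rank=4)
print(layer(torch.randn(2, 64)).shape)  # torch.Size([2, 32])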
@@ -26,13 +26,16 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
     )


 def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
-    state_dict = _convert_lora_bfl_control(state_dict=state_dict)
+    converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
     # Group keys by layer.
     grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
-    for key, value in state_dict.items():
+    for key, value in converted_state_dict.items():
         key_props = key.split(".")
-        layer_name = ".".join(key_props[:-1])
-        param_name = key_props[-1]
+        # Got it loading using lora_down and lora_up, but that didn't seem to match this LoRA's structure.
+        # Leaving this in since it doesn't hurt anything and may be better.
+        layer_prop_size = -2 if any(prop in key for prop in ["lora_down", "lora_up"]) else -1
+        layer_name = ".".join(key_props[:layer_prop_size])
+        param_name = ".".join(key_props[layer_prop_size:])
         if layer_name not in grouped_state_dict:
             grouped_state_dict[layer_name] = {}
         grouped_state_dict[layer_name][param_name] = value
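A worked example of the grouping rule above: lora_down/lora_up keys keep their last two dotted components together as the parameter name, while everything else splits on the final dot. The sample keys are made up for illustration:

keys = [
    "double_blocks.0.img_attn.qkv.lora_A.weight",
    "double_blocks.0.img_attn.qkv.lora_down.weight",
]
for key in keys:
    key_props = key.split(".")
    layer_prop_size = -2 if any(p in key for p in ["lora_down", "lora_up"]) else -1
    print(".".join(key_props[:layer_prop_size]), "->", ".".join(key_props[layer_prop_size:]))
# double_blocks.0.img_attn.qkv.lora_A -> weight
# double_blocks.0.img_attn.qkv -> lora_down.weight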
@@ -50,10 +53,11 @@
 def _convert_lora_bfl_control(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
     sd_out: dict[str, torch.Tensor] = {}
     for k in state_dict:
+        if k.endswith(".scale"):  # TODO: Fix these patches
+            continue
         k_to = k.replace(".lora_B.bias", ".lora_B.diff_b")\
             .replace(".lora_A.weight", ".lora_A.diff")\
-            .replace(".lora_B.weight", ".lora_B.diff")\
-            .replace("_norm.scale", "_norm.scale.set_weight")
+            .replace(".lora_B.weight", ".lora_B.diff")
         sd_out[k_to] = state_dict[k]

     # sd_out["img_in.reshape_weight"] = torch.tensor([state_dict["img_in.lora_B.weight"].shape[0], state_dict["img_in.lora_A.weight"].shape[1]])
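A worked example of the conversion on made-up BFL-style keys; note that .scale entries are now skipped entirely rather than renamed to set_weight:

state_dict_keys = [
    "img_in.lora_A.weight",
    "img_in.lora_B.bias",
    "double_blocks.0.img_norm.scale",
]
for k in state_dict_keys:
    if k.endswith(".scale"):  # dropped, matching the TODO above
        continue
    k_to = (
        k.replace(".lora_B.bias", ".lora_B.diff_b")
        .replace(".lora_A.weight", ".lora_A.diff")
        .replace(".lora_B.weight", ".lora_B.diff")
    )
    print(k, "->", k_to)
# img_in.lora_A.weight -> img_in.lora_A.diff
# img_in.lora_B.bias -> img_in.lora_B.diff_b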
@@ -7,17 +7,22 @@ from invokeai.backend.util.calc_tensor_size import calc_tensor_size


 class SetWeightLayer(LoRALayerBase):
-    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
+    # TODO: Just everything in this class
+    def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
         super().__init__(alpha=None, bias=bias)
-        self.weight = torch.nn.Parameter(weight)
+        self.weight = torch.nn.Parameter(weight) if weight is not None else None
+        self.manual_scale = scale
+
+    def scale(self):
+        return self.manual_scale.float() if self.manual_scale is not None else super().scale()

     @classmethod
     def from_state_dict_values(
         cls,
         values: Dict[str, torch.Tensor],
     ):
-        layer = cls(weight=values["set_weight"], bias=values.get("set_bias", None))
-        cls.warn_on_unhandled_keys(values=values, handled_keys={"set_weight", "set_bias"})
+        layer = cls(weight=values.get("set_weight", None), bias=values.get("set_bias", None), scale=values.get("set_scale", None))
+        cls.warn_on_unhandled_keys(values=values, handled_keys={"set_weight", "set_bias", "set_scale"})
         return layer

     def rank(self) -> int | None:
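A minimal sketch of the now fully optional construction path, with a stub class standing in for LoRALayerBase (not the real InvokeAI class): every one of the three keys may be absent, so .get() replaces direct indexing.

from typing import Dict, Optional
import torch

class StubSetWeightLayer:
    def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
        self.weight = torch.nn.Parameter(weight) if weight is not None else None
        self.bias = bias
        self.manual_scale = scale

    @classmethod
    def from_state_dict_values(cls, values: Dict[str, torch.Tensor]) -> "StubSetWeightLayer":
        # Any subset of the three keys may appear in the state dict.
        return cls(
            weight=values.get("set_weight", None),
            bias=values.get("set_bias", None),
            scale=values.get("set_scale", None),
        )

layer = StubSetWeightLayer.from_state_dict_values({"set_scale": torch.tensor(1.5)})
print(layer.weight, layer.manual_scale)  # None tensor(1.5000)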
@@ -28,7 +33,10 @@ class SetWeightLayer(LoRALayerBase):

     def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
         super().to(device=device, dtype=dtype)
-        self.weight = self.weight.to(device=device, dtype=dtype)
+        if self.weight is not None:
+            self.weight = self.weight.to(device=device, dtype=dtype)
+        if self.manual_scale is not None:
+            self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)

     def calc_size(self) -> int:
-        return super().calc_size() + calc_tensor_size(self.weight)
+        return super().calc_size() + calc_tensor_size(self.manual_scale)
@@ -30,7 +30,7 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
         return IA3Layer.from_state_dict_values(state_dict)
     elif "w_norm" in state_dict:
         return NormLayer.from_state_dict_values(state_dict)
-    elif "set_weight" in state_dict:
+    elif any(key in state_dict for key in ["set_weight", "set_bias", "set_scale"]):
         return SetWeightLayer.from_state_dict_values(state_dict)
     else:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
@@ -6,7 +6,6 @@ import torch
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.layers.set_weight_layer import SetWeightLayer
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
     ConcatenatedLoRALinearSidecarLayer,
@@ -116,8 +115,6 @@ class LoRAPatcher:

             if module_param.shape != lora_param_weight.shape:
                 lora_param_weight = lora_param_weight.reshape(module_param.shape)
-            if isinstance(layer, SetWeightLayer):
-                module_param = lora_param_weight
             lora_param_weight *= patch_weight * layer_scale
             module_param += lora_param_weight.to(dtype=dtype)
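With the SetWeightLayer special case removed, every layer type now flows through the same in-place update, module_param += layer_param * patch_weight * layer_scale. A toy version of that arithmetic, with made-up values:

import torch

module_param = torch.ones(2, 2)            # existing model weight
lora_param_weight = torch.full((4,), 0.5)  # patch values from the LoRA layer
patch_weight, layer_scale = 0.8, 1.0

if module_param.shape != lora_param_weight.shape:
    lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=module_param.dtype)
print(module_param)  # each entry: 1.0 + 0.5 * 0.8 = 1.4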