Ryan Dick 2024-12-12 22:28:44 +00:00
parent 9019026d6d
commit 5422bb74c6
15 changed files with 55 additions and 39 deletions

View File

@@ -8,8 +8,6 @@ import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
@@ -24,7 +22,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
from invokeai.app.invocations.model import LoRAField, StructuralLoRAField, TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -35,8 +33,10 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
generate_img_ids,
@@ -45,8 +45,6 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -359,7 +357,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond
img_cond=img_cond,
)
x = unpack(x.float(), self.height, self.width)

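For orientation on the `pack`/`unpack` imports this hunk reorders: FLUX denoises a token sequence built from 2x2 latent patches, and the `x = unpack(...)` call above restores the spatial layout after sampling. Below is a minimal round-trip sketch using the BFL-style einops patterns; InvokeAI's exact implementation may differ in details such as odd-size handling.

```python
import torch
from einops import rearrange

def pack(x: torch.Tensor) -> torch.Tensor:
    # (b, c, h, w) -> (b, h*w/4, 4*c): each 2x2 latent patch becomes one token.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # Inverse of pack(); height/width are image dims, latents are 8x smaller,
    # so the token grid is height//16 by width//16.
    return rearrange(
        x, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=height // 16, w=width // 16, ph=2, pw=2,
    )

latents = torch.randn(1, 16, 64, 64)  # 16-channel FLUX latents for a 512x512 image
tokens = pack(latents)                # shape (1, 1024, 64)
assert torch.equal(unpack(tokens, 512, 512), latents)
```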
View File

@@ -82,7 +82,9 @@ class FluxModelLoaderInvocation(BaseInvocation):
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
clip=CLIPField(
tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0
),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],

View File

@@ -1,4 +1,4 @@
from typing import Optional, Literal
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -7,8 +7,8 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField, StructuralLoRAField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -53,7 +53,11 @@ class FluxStructuralLoRALoaderInvocation(BaseInvocation):
raise ValueError(f"Unknown lora: {lora_key}!")
# Check for existing LoRAs with the same key.
if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
if (
self.transformer
and self.transformer.structural_lora
and self.transformer.structural_lora.lora.key == lora_key
):
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
output = FluxStructuralLoRALoaderOutput()

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Optional, Literal
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
@@ -74,13 +74,18 @@ class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class StructuralLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)
structural_lora: Optional[StructuralLoRAField] = Field(
description="Structural LoRAs to apply on model loading", default=None
)
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):

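To make the relationship between the reformatted fields concrete, here is a self-contained pydantic sketch that mirrors the structure above without requiring InvokeAI installed. The `*Sketch` names, string keys, and `weight` default are illustrative assumptions: a structural LoRA is an ordinary LoRA plus a conditioning image, and a transformer field carries at most one of them.

```python
from typing import List, Optional
from pydantic import BaseModel, Field

class LoRAFieldSketch(BaseModel):
    lora_key: str            # stands in for ModelIdentifierField
    weight: float = 1.0      # assumed default

class StructuralLoRAFieldSketch(LoRAFieldSketch):
    img_name: str            # image to use in structural conditioning

class TransformerFieldSketch(BaseModel):
    transformer_key: str
    loras: List[LoRAFieldSketch] = Field(default_factory=list)
    structural_lora: Optional[StructuralLoRAFieldSketch] = None  # at most one

tf = TransformerFieldSketch(
    transformer_key="flux-dev-transformer",
    structural_lora=StructuralLoRAFieldSketch(lora_key="flux-canny-lora", img_name="edges.png"),
)
assert tf.loras == [] and tf.structural_lora is not None
```

The loader invocation earlier in this commit raises if a structural LoRA with the same key is already attached, which is why `structural_lora` is a single optional field rather than a list.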
View File

@@ -1,10 +1,11 @@
import torch
import numpy as np
from PIL import Image
import torch
from einops import rearrange
from PIL import Image
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def prepare_control(
height: int,
width: int,

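Only the start of `prepare_control`'s signature survives in this hunk. Based on the imports it gains here (`AutoEncoder`, `einops`, `PIL`) and the `img_cond` argument threaded into the denoise call earlier, a heavily hedged sketch of the flow is: resize the preprocessed control image to the generation size, VAE-encode it, and pack it like the latents. Everything below (name, signature, interpolation mode) is an assumption, not InvokeAI's actual code.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def prepare_control_sketch(ae, control_img: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # control_img: (1, 3, h, w) in [-1, 1], already run through a depth/canny encoder.
    img = F.interpolate(control_img, size=(height, width), mode="bilinear")
    with torch.no_grad():
        latent = ae.encode(img)  # assumes the AutoEncoder exposes encode()
    # Pack 2x2 latent patches into tokens so img_cond aligns with the noised
    # latents it is combined with inside the denoise loop.
    return rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
```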
View File

@@ -1,10 +1,10 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor, nn
from typing import Optional
from invokeai.backend.flux.custom_block_processor import (
CustomDoubleStreamBlockProcessor,

View File

@@ -1,20 +1,18 @@
import os
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
from transformers import AutoModelForDepthEstimation, AutoProcessor
class DepthImageEncoder:
depth_model_name = "LiheYoung/depth-anything-large-hf"
def __init__(self, device):
self.device = device
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
def __call__(self, img: torch.Tensor) -> torch.Tensor:
hw = img.shape[-2:]
img = torch.clamp(img, -1.0, 1.0)
@@ -26,6 +24,7 @@ class DepthImageEncoder:
depth = depth / 127.5 - 1.0
return depth
class CannyImageEncoder:
def __init__(
self,
@@ -36,6 +35,7 @@ class CannyImageEncoder:
self.device = device
self.min_t = min_t
self.max_t = max_t
def __call__(self, img: torch.Tensor) -> torch.Tensor:
assert img.shape[0] == 1, "Only batch size 1 is supported"
img = rearrange(img[0], "c h w -> h w c")

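An illustrative usage sketch for the two encoders above. It assumes both classes are importable from this module and that `min_t`/`max_t` have defaults; both take image tensors in [-1, 1] (`DepthImageEncoder` clamps explicitly) and return tensors in the same range.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img = torch.rand(1, 3, 512, 512, device=device) * 2.0 - 1.0  # (b, c, h, w) in [-1, 1]

depth_encoder = DepthImageEncoder(device=device)
depth = depth_encoder(img)   # depth map rescaled via depth / 127.5 - 1.0

canny_encoder = CannyImageEncoder(device=device)  # assumed default min_t/max_t
edges = canny_encoder(img)   # batch size must be 1 (asserted in __call__)
```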
View File

@@ -1,14 +1,13 @@
import re
from typing import Any, Dict
import torch
from typing import Any, Dict
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
@@ -17,6 +16,7 @@ from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
# double_blocks.0.img_attn.norm.key_norm.scale
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
@@ -28,6 +28,7 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
for k in state_dict.keys()
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
# Group keys by layer.
@@ -54,7 +55,7 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"]
layer_state_dict["lora_B.bias"],
)
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
@@ -62,4 +63,3 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
raise AssertionError(f"{layer_key} not expected")
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)

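A quick sanity check of the key regex defined in this file. The first key is taken verbatim from the comment above; the other two are constructed to the same pattern for illustration, and the import assumes the constant stays module-level in `flux_control_lora_utils`.

```python
import re

from invokeai.backend.lora.conversions.flux_control_lora_utils import (
    FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX,
)

pattern = re.compile(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX)
for key in [
    "double_blocks.0.img_attn.norm.key_norm.scale",  # from the comment above
    "double_blocks.0.img_attn.qkv.lora_A.weight",    # constructed example
    "guidance_in.in_layer.bias",                     # constructed example
]:
    assert pattern.match(key), f"regex should match {key}"
```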
View File

@@ -9,4 +9,6 @@ from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]
AnyLoRALayer = Union[
LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer
]

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Optional
import torch
@@ -7,7 +7,7 @@ from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class ReshapeWeightLayer(LoRALayerBase):
# TODO: Just everything in this class
# TODO: Just everything in this class
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight) if weight is not None else None

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict
import torch
@@ -17,7 +17,7 @@ class SetParameterLayer(LoRALayerBase):
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight - orig_weight
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}

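The `get_weight()` shown above returns a difference rather than the stored value, which is worth spelling out: if the patching machinery applies layer weights additively (which is what the subtraction implies), returning `target - orig` turns the addition into an outright replacement of the parameter. A worked example:

```python
import torch

orig = torch.tensor([1.0, 2.0, 3.0])     # original module parameter
target = torch.tensor([0.5, 0.5, 0.5])   # value SetParameterLayer stores

delta = target - orig                     # what get_weight(orig) computes
patched = orig + delta                    # what an additive patcher applies
assert torch.equal(patched, target)       # parameter is set, not offset
```

`get_parameters()` then keys this delta by `param_name`, so the patcher looks up exactly one parameter on the original module.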
View File

@@ -9,7 +9,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:

View File

@@ -9,13 +9,17 @@ import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (

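For context, the detection/conversion helpers imported above are typically chained into a format dispatch when a LoRA file is loaded. A hedged sketch follows; the loader's real branching order and the converters' exact signatures may differ, and the path is a placeholder.

```python
from safetensors.torch import load_file

state_dict = load_file("lora.safetensors")  # placeholder path

if is_state_dict_likely_flux_control(state_dict):
    model = lora_model_from_flux_control_state_dict(state_dict)
elif is_state_dict_likely_in_flux_kohya_format(state_dict):
    model = lora_model_from_flux_kohya_state_dict(state_dict)
else:
    # Fall through to the Diffusers converter; signature assumed minimal here.
    model = lora_model_from_flux_diffusers_state_dict(state_dict)
```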
View File

@@ -15,10 +15,10 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (

View File

@@ -23,6 +23,7 @@ def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, lis
assert is_state_dict_likely_flux_control(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers