mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
ruff
This commit is contained in:
parent
9019026d6d
commit
5422bb74c6
@ -8,8 +8,6 @@ import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
@ -24,7 +22,7 @@ from invokeai.app.invocations.fields import (
|
||||
)
|
||||
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
|
||||
from invokeai.app.invocations.model import LoRAField, StructuralLoRAField, TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
@ -35,8 +33,10 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.sampling_utils import (
|
||||
clip_timestep_schedule_fractional,
|
||||
generate_img_ids,
|
||||
@ -45,8 +45,6 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
@ -359,7 +357,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
controlnet_extensions=controlnet_extensions,
|
||||
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
img_cond=img_cond
|
||||
img_cond=img_cond,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
@ -82,7 +82,9 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
|
||||
clip=CLIPField(
|
||||
tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0
|
||||
),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Literal
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -7,8 +7,8 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
|
||||
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField, StructuralLoRAField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@ -53,7 +53,11 @@ class FluxStructuralLoRALoaderInvocation(BaseInvocation):
|
||||
raise ValueError(f"Unknown lora: {lora_key}!")
|
||||
|
||||
# Check for existing LoRAs with the same key.
|
||||
if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
|
||||
if (
|
||||
self.transformer
|
||||
and self.transformer.structural_lora
|
||||
and self.transformer.structural_lora.lora.key == lora_key
|
||||
):
|
||||
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
|
||||
|
||||
output = FluxStructuralLoRALoaderOutput()
|
||||
|
@ -1,5 +1,5 @@
|
||||
import copy
|
||||
from typing import List, Optional, Literal
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -74,13 +74,18 @@ class VAEField(BaseModel):
|
||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
|
||||
class StructuralLoRAField(LoRAField):
|
||||
img: ImageField = Field(description="Image to use in structural conditioning")
|
||||
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)
|
||||
structural_lora: Optional[StructuralLoRAField] = Field(
|
||||
description="Structural LoRAs to apply on model loading", default=None
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("unet_output")
|
||||
class UNetOutput(BaseInvocationOutput):
|
||||
|
@ -1,10 +1,11 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
|
||||
|
||||
def prepare_control(
|
||||
height: int,
|
||||
width: int,
|
||||
|
@ -1,10 +1,10 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.flux.custom_block_processor import (
|
||||
CustomDoubleStreamBlockProcessor,
|
||||
|
@ -1,20 +1,18 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file as load_sft
|
||||
from torch import nn
|
||||
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
|
||||
from transformers import AutoModelForDepthEstimation, AutoProcessor
|
||||
|
||||
|
||||
class DepthImageEncoder:
|
||||
depth_model_name = "LiheYoung/depth-anything-large-hf"
|
||||
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
|
||||
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
|
||||
|
||||
def __call__(self, img: torch.Tensor) -> torch.Tensor:
|
||||
hw = img.shape[-2:]
|
||||
img = torch.clamp(img, -1.0, 1.0)
|
||||
@ -26,6 +24,7 @@ class DepthImageEncoder:
|
||||
depth = depth / 127.5 - 1.0
|
||||
return depth
|
||||
|
||||
|
||||
class CannyImageEncoder:
|
||||
def __init__(
|
||||
self,
|
||||
@ -36,6 +35,7 @@ class CannyImageEncoder:
|
||||
self.device = device
|
||||
self.min_t = min_t
|
||||
self.max_t = max_t
|
||||
|
||||
def __call__(self, img: torch.Tensor) -> torch.Tensor:
|
||||
assert img.shape[0] == 1, "Only batch size 1 is supported"
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
|
@ -1,14 +1,13 @@
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Any, Dict
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
|
||||
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
|
||||
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
|
||||
# Example keys:
|
||||
@ -17,6 +16,7 @@ from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
|
||||
# double_blocks.0.img_attn.norm.key_norm.scale
|
||||
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
|
||||
|
||||
|
||||
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
|
||||
|
||||
@ -28,6 +28,7 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
|
||||
for k in state_dict.keys()
|
||||
)
|
||||
|
||||
|
||||
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
|
||||
# Group keys by layer.
|
||||
@ -54,7 +55,7 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
|
||||
None,
|
||||
layer_state_dict["lora_A.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_B.bias"]
|
||||
layer_state_dict["lora_B.bias"],
|
||||
)
|
||||
elif "scale" in layer_state_dict:
|
||||
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
|
||||
@ -62,4 +63,3 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
|
||||
raise AssertionError(f"{layer_key} not expected")
|
||||
# Create and return the LoRAModelRaw.
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
|
@ -9,4 +9,6 @@ from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]
|
||||
AnyLoRALayer = Union[
|
||||
LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer
|
||||
]
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -7,7 +7,7 @@ from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
class ReshapeWeightLayer(LoRALayerBase):
|
||||
# TODO: Just everything in this class
|
||||
# TODO: Just everything in this class
|
||||
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
|
||||
super().__init__(alpha=None, bias=bias)
|
||||
self.weight = torch.nn.Parameter(weight) if weight is not None else None
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
@ -17,7 +17,7 @@ class SetParameterLayer(LoRALayerBase):
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight - orig_weight
|
||||
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}
|
||||
|
||||
|
@ -9,7 +9,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
|
||||
|
||||
|
||||
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
|
||||
|
@ -9,13 +9,17 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora.conversions.flux_control_lora_utils import (
|
||||
is_state_dict_likely_flux_control,
|
||||
lora_model_from_flux_control_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
|
||||
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.model_manager import (
|
||||
|
@ -15,10 +15,10 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
|
||||
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
|
@ -23,6 +23,7 @@ def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, lis
|
||||
|
||||
assert is_state_dict_likely_flux_control(state_dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
|
||||
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
|
||||
"""Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers
|
||||
|
Loading…
Reference in New Issue
Block a user