mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-03 07:21:32 +08:00
736 lines
34 KiB
Python
736 lines
34 KiB
Python
from contextlib import ExitStack
|
|
from typing import Callable, Iterator, Optional, Tuple, Union
|
|
|
|
import einops
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
import torch
|
|
import torchvision.transforms as tv_transforms
|
|
from PIL import Image
|
|
from torchvision.transforms.functional import resize as tv_resize
|
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
|
from invokeai.app.invocations.fields import (
|
|
DenoiseMaskField,
|
|
FieldDescriptions,
|
|
FluxConditioningField,
|
|
ImageField,
|
|
Input,
|
|
InputField,
|
|
LatentsField,
|
|
WithBoard,
|
|
WithMetadata,
|
|
)
|
|
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
|
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
|
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
|
|
from invokeai.app.invocations.primitives import LatentsOutput
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
|
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
|
from invokeai.backend.flux.denoise import denoise
|
|
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
|
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
|
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
|
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
|
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
|
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
|
|
from invokeai.backend.flux.model import Flux
|
|
from invokeai.backend.flux.sampling_utils import (
|
|
clip_timestep_schedule_fractional,
|
|
generate_img_ids,
|
|
get_noise,
|
|
get_schedule,
|
|
pack,
|
|
unpack,
|
|
)
|
|
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
|
from invokeai.backend.model_manager.config import ModelFormat
|
|
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
|
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
|
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
|
from invokeai.backend.util.devices import TorchDevice
|
|
|
|
|
|
@invocation(
    "flux_denoise",
    title="FLUX Denoise",
    tags=["image", "flux"],
    category="image",
    version="3.2.2",
    classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run denoising process with a FLUX transformer model."""

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.denoise_mask,
        input=Input.Connection,
    )
    denoising_start: float = InputField(
        default=0.0,
        ge=0,
        le=1,
        description=FieldDescriptions.denoising_start,
    )
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    control_lora: Optional[ControlLoRAField] = InputField(
        description=FieldDescriptions.control_lora_model, input=Input.Connection, title="Control LoRA", default=None
    )
    positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
        default=None,
        description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
        input=Input.Connection,
    )
    cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    cfg_scale_start_step: int = InputField(
        default=0,
        title="CFG Scale Start Step",
        description="Index of the first step to apply cfg_scale. Negative indices count backwards from the "
        + "last step (e.g. a value of -1 refers to the final step).",
    )
    cfg_scale_end_step: int = InputField(
        default=-1,
        title="CFG Scale End Step",
        description="Index of the last step to apply cfg_scale. Negative indices count backwards from the "
        + "last step (e.g. a value of -1 refers to the final step).",
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
    control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
        default=None, input=Input.Connection, description="ControlNet models."
    )
    controlnet_vae: VAEField | None = InputField(
        default=None,
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
        description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run denoising, move the result to the CPU, and save it to the tensor store as a LatentsOutput."""
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
        """Run the FLUX denoising process and return the resulting latents in unpacked format.

        Handles text-to-image (starting from noise) and image-to-image (starting from `self.latents`), with optional
        inpainting, regional prompting, ControlNets, IP-Adapters, LoRAs and a structural control LoRA.
        """
        inference_dtype = torch.bfloat16
        # Resolve the target device once; it is used for all tensors prepared below.
        device = TorchDevice.choose_torch_device()

        # Load the input latents, if provided.
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)

        # Prepare input noise.
        noise = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=device,
            dtype=inference_dtype,
            seed=self.seed,
        )
        b, _c, latent_h, latent_w = noise.shape
        # 'Packing' folds 2x2 latent patches into the sequence dimension, so the packed grid is half-size per axis.
        packed_h = latent_h // 2
        packed_w = latent_w // 2

        # Load the conditioning data.
        pos_text_conditionings = self._load_text_conditioning(
            context=context,
            cond_field=self.positive_text_conditioning,
            packed_height=packed_h,
            packed_width=packed_w,
            dtype=inference_dtype,
            device=device,
        )
        neg_text_conditionings: list[FluxTextConditioning] | None = None
        if self.negative_text_conditioning is not None:
            neg_text_conditionings = self._load_text_conditioning(
                context=context,
                cond_field=self.negative_text_conditioning,
                packed_height=packed_h,
                packed_width=packed_w,
                dtype=inference_dtype,
                device=device,
            )
        pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
            pos_text_conditionings, img_seq_len=packed_h * packed_w
        )
        neg_regional_prompting_extension = (
            RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
            if neg_text_conditionings
            else None
        )

        transformer_info = context.models.load(self.transformer.transformer)
        # getattr's default only covers a missing attribute; `or ""` also guards against config_path being None.
        config_path = getattr(transformer_info.config, "config_path", None) or ""
        is_schnell = "schnell" in config_path

        # Calculate the timestep schedule.
        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=packed_h * packed_w,
            shift=not is_schnell,
        )

        # Clip the timesteps schedule based on denoising_start and denoising_end.
        timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)

        # Prepare input latent image.
        if init_latents is not None:
            # If init_latents is provided, we are doing image-to-image.

            if is_schnell:
                context.logger.warning(
                    "Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
                    "to be poor. Consider using a FLUX dev model instead."
                )

            if self.add_noise:
                # Noise the orig_latents by the appropriate amount for the first timestep.
                t_0 = timesteps[0]
                x = t_0 * noise + (1.0 - t_0) * init_latents
            else:
                x = init_latents
        else:
            # init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")

            x = noise

        # If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
        # denoising steps.
        # NOTE(review): this early return yields latents in inference_dtype (bfloat16), whereas the full path below
        # returns float32 via `x.float()`. Confirm downstream consumers are dtype-agnostic.
        if len(timesteps) <= 1:
            return x

        if is_schnell and self.control_lora:
            raise ValueError("Control LoRAs cannot be used with FLUX Schnell")

        # Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
        img_cond = self._prep_structural_control_img_cond(context)

        inpaint_mask = self._prep_inpaint_mask(context, x)

        img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

        # Pack all latent tensors.
        init_latents = pack(init_latents) if init_latents is not None else None
        inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
        img_cond = pack(img_cond) if img_cond is not None else None
        noise = pack(noise)
        x = pack(x)

        # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
        # packed_w correctly.
        assert packed_h * packed_w == x.shape[1]

        # Prepare inpaint extension.
        inpaint_extension: InpaintExtension | None = None
        if inpaint_mask is not None:
            # A denoise_mask is only meaningful for image-to-image.
            assert init_latents is not None
            inpaint_extension = InpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        # Compute the IP-Adapter image prompt clip embeddings.
        # We do this before loading other models to minimize peak memory.
        # TODO(ryand): We should really do this in a separate invocation to benefit from caching.
        ip_adapter_fields = self._normalize_ip_adapter_fields()
        pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
            ip_adapter_fields, context
        )

        cfg_scale = self.prep_cfg_scale(
            cfg_scale=self.cfg_scale,
            timesteps=timesteps,
            cfg_scale_start_step=self.cfg_scale_start_step,
            cfg_scale_end_step=self.cfg_scale_end_step,
        )

        with ExitStack() as exit_stack:
            # Prepare ControlNet extensions.
            # Note: We do this before loading the transformer model to minimize peak memory (see implementation).
            controlnet_extensions = self._prep_controlnet_extensions(
                context=context,
                exit_stack=exit_stack,
                latent_height=latent_h,
                latent_width=latent_w,
                dtype=inference_dtype,
                device=x.device,
            )

            # Load the transformer model.
            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
            assert isinstance(transformer, Flux)
            config = transformer_info.config
            assert config is not None

            # Determine if the model is quantized.
            # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
            # slower inference than direct patching, but is agnostic to the quantization format.
            if config.format in [ModelFormat.Checkpoint]:
                model_is_quantized = False
            elif config.format in [
                ModelFormat.BnbQuantizedLlmInt8b,
                ModelFormat.BnbQuantizednf4b,
                ModelFormat.GGUFQuantized,
            ]:
                model_is_quantized = True
            else:
                raise ValueError(f"Unsupported model format: {config.format}")

            # Apply LoRA models to the transformer.
            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                    cached_weights=cached_weights,
                    force_sidecar_patching=model_is_quantized,
                )
            )

            # Prepare IP-Adapter extensions.
            pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
                pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
                neg_image_prompt_clip_embeds=neg_image_prompt_clip_embeds,
                ip_adapter_fields=ip_adapter_fields,
                context=context,
                exit_stack=exit_stack,
                dtype=inference_dtype,
            )

            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                pos_regional_prompting_extension=pos_regional_prompting_extension,
                neg_regional_prompting_extension=neg_regional_prompting_extension,
                timesteps=timesteps,
                step_callback=self._build_step_callback(context),
                guidance=self.guidance,
                cfg_scale=cfg_scale,
                inpaint_extension=inpaint_extension,
                controlnet_extensions=controlnet_extensions,
                pos_ip_adapter_extensions=pos_ip_adapter_extensions,
                neg_ip_adapter_extensions=neg_ip_adapter_extensions,
                img_cond=img_cond,
            )

        x = unpack(x.float(), self.height, self.width)
        return x

    def _load_text_conditioning(
        self,
        context: InvocationContext,
        cond_field: FluxConditioningField | list[FluxConditioningField],
        packed_height: int,
        packed_width: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> list[FluxTextConditioning]:
        """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields.

        Args:
            context: The invocation context, used to load conditioning data and mask tensors.
            cond_field: A single conditioning field or a list of them.
            packed_height: Height of the packed latent grid, used to preprocess regional prompt masks.
            packed_width: Width of the packed latent grid, used to preprocess regional prompt masks.
            dtype: Target dtype for the embeddings and masks.
            device: Target device for the embeddings and masks.

        Returns:
            One FluxTextConditioning per input field, in input order.
        """
        # Normalize to a list of FluxConditioningFields.
        cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field

        text_conditionings: list[FluxTextConditioning] = []
        for field in cond_list:
            # Load the text embeddings.
            cond_data = context.conditioning.load(field.conditioning_name)
            assert len(cond_data.conditionings) == 1
            flux_conditioning = cond_data.conditionings[0]
            assert isinstance(flux_conditioning, FLUXConditioningInfo)
            flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
            t5_embeddings = flux_conditioning.t5_embeds
            clip_embeddings = flux_conditioning.clip_embeds

            # Load the mask, if provided.
            mask: Optional[torch.Tensor] = None
            if field.mask is not None:
                mask = context.tensors.load(field.mask.tensor_name)
                mask = mask.to(device=device)
                mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
                    mask, packed_height, packed_width, dtype, device
                )

            text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))

        return text_conditionings

    @classmethod
    def prep_cfg_scale(
        cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
    ) -> list[float]:
        """Prepare the per-step cfg_scale schedule.

        - If cfg_scale is a scalar, it is broadcast to a constant schedule with one entry per denoising step.
        - If cfg_scale is a list, it is treated as a per-step schedule and must have one entry per denoising step.
        - Entries outside the inclusive range [cfg_scale_start_step, cfg_scale_end_step] are set to 1.0, which
          disables CFG for those steps. Negative indices count backwards from the last step.

        Args:
            cfg_scale: A scalar CFG scale, or a list with one CFG scale per denoising step.
            timesteps: The timestep schedule. The number of denoising steps is len(timesteps) - 1.
            cfg_scale_start_step: Index of the first step to apply cfg_scale.
            cfg_scale_end_step: Index of the last step to apply cfg_scale.

        Returns:
            A new list with one cfg_scale value per denoising step.

        Raises:
            ValueError: If cfg_scale has an unsupported type or the wrong length, or if the start/end step indices
                are out of range or out of order.
        """
        # num_steps is the number of denoising steps, which is one less than the number of timesteps.
        num_steps = len(timesteps) - 1

        # Normalize cfg_scale to a list if it is a scalar. Ints are accepted as scalars; bools are rejected because
        # bool is a subclass of int and True/False are almost certainly a caller error here.
        cfg_scale_list: list[float]
        if isinstance(cfg_scale, (int, float)) and not isinstance(cfg_scale, bool):
            cfg_scale_list = [float(cfg_scale)] * num_steps
        elif isinstance(cfg_scale, list):
            cfg_scale_list = cfg_scale
        else:
            raise ValueError(f"Unsupported cfg_scale type: {type(cfg_scale)}")
        # Explicit check rather than `assert`, which is stripped when Python runs with -O.
        if len(cfg_scale_list) != num_steps:
            raise ValueError(
                f"cfg_scale schedule has {len(cfg_scale_list)} entries, but there are {num_steps} denoising steps."
            )

        # Handle negative indices for cfg_scale_start_step and cfg_scale_end_step.
        start_step_index = cfg_scale_start_step
        if start_step_index < 0:
            start_step_index = num_steps + start_step_index
        end_step_index = cfg_scale_end_step
        if end_step_index < 0:
            end_step_index = num_steps + end_step_index

        # Validate the start and end step indices.
        if not (0 <= start_step_index < num_steps):
            raise ValueError(f"Invalid cfg_scale_start_step. Out of range: {cfg_scale_start_step}.")
        if not (0 <= end_step_index < num_steps):
            raise ValueError(f"Invalid cfg_scale_end_step. Out of range: {cfg_scale_end_step}.")
        if start_step_index > end_step_index:
            raise ValueError(
                f"cfg_scale_start_step ({cfg_scale_start_step}) must be before cfg_scale_end_step "
                + f"({cfg_scale_end_step})."
            )

        # Set values outside the start and end step indices to 1.0. This is equivalent to disabling cfg_scale for those
        # steps.
        clipped_cfg_scale = [1.0] * num_steps
        clipped_cfg_scale[start_step_index : end_step_index + 1] = cfg_scale_list[start_step_index : end_step_index + 1]

        return clipped_cfg_scale

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask.

        - Loads the mask
        - Inverts it (see Returns)
        - Resizes it to the latent spatial size
        - Casts to same device/dtype as latents
        - Expands mask to the same shape as latents so that they line up after 'packing'

        Args:
            context (InvocationContext): The invocation context, for loading the inpaint mask.
            latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
                device, and dtype for the inpaint mask.

        Returns:
            torch.Tensor | None: Inpaint mask, or None if no denoise_mask is set. Because the input mask is inverted,
                values of 1.0 represent the regions to be fully denoised, and 0.0 represent the regions to be
                preserved (the opposite of the input denoise_mask convention).
        """
        if self.denoise_mask is None:
            return None

        mask = context.tensors.load(self.denoise_mask.mask_name)

        # The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
        # 1.0 represents the regions to be preserved.
        # We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
        mask = 1.0 - mask

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)

        # Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
        # `latents`.
        return mask.expand_as(latents)

    def _prep_controlnet_extensions(
        self,
        context: InvocationContext,
        exit_stack: ExitStack,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
        """Build one ControlNet extension per connected ControlNet.

        The conditioning tensors (which may require running the VAE) are computed before any ControlNet model context
        is entered on `exit_stack`, to keep peak memory down.

        Raises:
            ValueError: If the control input or a ControlNet model has an unsupported type, or if an InstantX
                ControlNet is used without a controlnet_vae.
        """
        # Normalize the controlnet input to list[ControlField].
        controlnets: list[FluxControlNetField]
        if self.control is None:
            controlnets = []
        elif isinstance(self.control, FluxControlNetField):
            controlnets = [self.control]
        elif isinstance(self.control, list):
            controlnets = self.control
        else:
            raise ValueError(f"Unsupported controlnet type: {type(self.control)}")

        # TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
        # before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
        # minimize peak memory.

        # First, load the ControlNet models so that we can determine the ControlNet types.
        controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]

        # Calculate the controlnet conditioning tensors.
        # We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
        # keep peak memory down.
        controlnet_conds: list[torch.Tensor] = []
        for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
            image = context.images.get_pil(controlnet.image.image_name)
            if isinstance(controlnet_model.model, InstantXControlNetFlux):
                if self.controlnet_vae is None:
                    raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
                vae_info = context.models.load(self.controlnet_vae.vae)
                controlnet_conds.append(
                    InstantXControlNetExtension.prepare_controlnet_cond(
                        controlnet_image=image,
                        vae_info=vae_info,
                        latent_height=latent_height,
                        latent_width=latent_width,
                        dtype=dtype,
                        device=device,
                        resize_mode=controlnet.resize_mode,
                    )
                )
            elif isinstance(controlnet_model.model, XLabsControlNetFlux):
                controlnet_conds.append(
                    XLabsControlNetExtension.prepare_controlnet_cond(
                        controlnet_image=image,
                        latent_height=latent_height,
                        latent_width=latent_width,
                        dtype=dtype,
                        device=device,
                        resize_mode=controlnet.resize_mode,
                    )
                )
            else:
                # Fail here with a clear message rather than later with an opaque length mismatch in the strict zip
                # below.
                raise ValueError(f"Unsupported ControlNet model type: {type(controlnet_model.model)}")

        # Finally, load the ControlNet models and initialize the ControlNet extensions.
        controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
        for controlnet, controlnet_cond, controlnet_model in zip(
            controlnets, controlnet_conds, controlnet_models, strict=True
        ):
            model = exit_stack.enter_context(controlnet_model)

            if isinstance(model, XLabsControlNetFlux):
                controlnet_extensions.append(
                    XLabsControlNetExtension(
                        model=model,
                        controlnet_cond=controlnet_cond,
                        weight=controlnet.control_weight,
                        begin_step_percent=controlnet.begin_step_percent,
                        end_step_percent=controlnet.end_step_percent,
                    )
                )
            elif isinstance(model, InstantXControlNetFlux):
                # A negative (or missing) control mode means "no control mode".
                instantx_control_mode: torch.Tensor | None = None
                if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
                    instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
                    instantx_control_mode = instantx_control_mode.reshape([-1, 1])

                controlnet_extensions.append(
                    InstantXControlNetExtension(
                        model=model,
                        controlnet_cond=controlnet_cond,
                        instantx_control_mode=instantx_control_mode,
                        weight=controlnet.control_weight,
                        begin_step_percent=controlnet.begin_step_percent,
                        end_step_percent=controlnet.end_step_percent,
                    )
                )
            else:
                raise ValueError(f"Unsupported ControlNet model type: {type(model)}")

        return controlnet_extensions

    def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
        """VAE-encode the FLUX structural control LoRA image, or return None if no control LoRA is set.

        The control image is converted to RGB, resized to (self.width, self.height), normalized to [-1, 1], and
        encoded with `self.controlnet_vae`.

        Raises:
            ValueError: If a control LoRA is set but controlnet_vae is not.
        """
        if self.control_lora is None:
            return None

        if not self.controlnet_vae:
            raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")

        # Load the conditioning image and resize it to the target image size.
        cond_img = context.images.get_pil(self.control_lora.img.image_name)
        cond_img = cond_img.convert("RGB")
        cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
        cond_img = np.array(cond_img)

        # Normalize the conditioning image to the range [-1, 1].
        # This normalization is based on the original implementations here:
        # https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L34
        # https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L60
        img_cond = torch.from_numpy(cond_img).float() / 127.5 - 1.0
        img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")

        vae_info = context.models.load(self.controlnet_vae.vae)
        return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)

    def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
        """Return `self.ip_adapter` as a list (empty when unset)."""
        if self.ip_adapter is None:
            return []
        elif isinstance(self.ip_adapter, IPAdapterField):
            return [self.ip_adapter]
        elif isinstance(self.ip_adapter, list):
            return self.ip_adapter
        else:
            raise ValueError(f"Unsupported IP-Adapter type: {type(self.ip_adapter)}")

    def _prep_ip_adapter_image_prompt_clip_embeds(
        self,
        ip_adapter_fields: list[IPAdapterField],
        context: InvocationContext,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings.

        Returns:
            A (positive, negative) pair of lists with one embedding tensor per IP-Adapter field. The negative
            embeddings are computed from an all-black image of the same size as the positive image.

        Raises:
            ValueError: If an IP-Adapter image has an unsupported type, or more than one image is provided.
        """
        clip_image_processor = CLIPImageProcessor()

        pos_image_prompt_clip_embeds: list[torch.Tensor] = []
        neg_image_prompt_clip_embeds: list[torch.Tensor] = []
        for ip_adapter_field in ip_adapter_fields:
            # `ip_adapter_field.image` could be a list or a single ImageField. Normalize to a list here.
            ipa_image_fields: list[ImageField]
            if isinstance(ip_adapter_field.image, ImageField):
                ipa_image_fields = [ip_adapter_field.image]
            elif isinstance(ip_adapter_field.image, list):
                ipa_image_fields = ip_adapter_field.image
            else:
                raise ValueError(f"Unsupported IP-Adapter image type: {type(ip_adapter_field.image)}")

            if len(ipa_image_fields) != 1:
                raise ValueError(
                    f"FLUX IP-Adapter only supports a single image prompt (received {len(ipa_image_fields)})."
                )

            ipa_images = [context.images.get_pil(image.image_name, mode="RGB") for image in ipa_image_fields]

            pos_images: list[npt.NDArray[np.uint8]] = []
            neg_images: list[npt.NDArray[np.uint8]] = []
            for ipa_image in ipa_images:
                assert ipa_image.mode == "RGB"
                pos_image = np.array(ipa_image)
                # We use a black image as the negative image prompt for parity with
                # https://github.com/XLabs-AI/x-flux-comfyui/blob/45c834727dd2141aebc505ae4b01f193a8414e38/nodes.py#L592-L593
                # An alternative scheme would be to apply zeros_like() after calling the clip_image_processor.
                neg_image = np.zeros_like(pos_image)
                pos_images.append(pos_image)
                neg_images.append(neg_image)

            with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
                assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)

                clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
                clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
                pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds

                clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
                clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
                neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds

            pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)
            neg_image_prompt_clip_embeds.append(neg_clip_image_embeds)

        return pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds

    def _prep_ip_adapter_extensions(
        self,
        ip_adapter_fields: list[IPAdapterField],
        pos_image_prompt_clip_embeds: list[torch.Tensor],
        neg_image_prompt_clip_embeds: list[torch.Tensor],
        context: InvocationContext,
        exit_stack: ExitStack,
        dtype: torch.dtype,
    ) -> tuple[list[XLabsIPAdapterExtension], list[XLabsIPAdapterExtension]]:
        """Build a positive and a negative XLabs IP-Adapter extension for each IP-Adapter field.

        Each IP-Adapter model is entered on `exit_stack` and cast to `dtype`. Both extensions for a field share the
        same model and differ only in the image prompt embedding they were given.

        Raises:
            ValueError: If any IP-Adapter field has a mask (not supported for FLUX).
        """
        pos_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
        neg_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
        for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
            ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
        ):
            # Validate before loading the model so an unsupported configuration fails fast.
            if ip_adapter_field.mask is not None:
                raise ValueError("IP-Adapter masks are not yet supported in Flux.")
            ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
            assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
            ip_adapter_model = ip_adapter_model.to(dtype=dtype)

            ip_adapter_extension = XLabsIPAdapterExtension(
                model=ip_adapter_model,
                image_prompt_clip_embed=pos_image_prompt_clip_embed,
                weight=ip_adapter_field.weight,
                begin_step_percent=ip_adapter_field.begin_step_percent,
                end_step_percent=ip_adapter_field.end_step_percent,
            )
            ip_adapter_extension.run_image_proj(dtype=dtype)
            pos_ip_adapter_extensions.append(ip_adapter_extension)

            ip_adapter_extension = XLabsIPAdapterExtension(
                model=ip_adapter_model,
                image_prompt_clip_embed=neg_image_prompt_clip_embed,
                weight=ip_adapter_field.weight,
                begin_step_percent=ip_adapter_field.begin_step_percent,
                end_step_percent=ip_adapter_field.end_step_percent,
            )
            ip_adapter_extension.run_image_proj(dtype=dtype)
            neg_ip_adapter_extensions.append(ip_adapter_extension)

        return pos_ip_adapter_extensions, neg_ip_adapter_extensions

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Yield (patch, weight) pairs for the transformer LoRAs, followed by the control LoRA if one is set."""
        loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
        if self.control_lora:
            # Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
            # applied last.
            loras.append(self.control_lora)
        for lora in loras:
            lora_info = context.models.load(lora.lora)
            assert isinstance(lora_info.model, ModelPatchRaw)
            yield (lora_info.model, lora.weight)
            # Drop the reference before loading the next LoRA.
            del lora_info

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        """Build a step callback that unpacks the intermediate latents before forwarding the state.

        The returned callback overwrites `state.latents` with the unpacked (and squeezed) latents, then calls
        `context.util.flux_step_callback(state)`.
        """

        def step_callback(state: PipelineIntermediateState) -> None:
            state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
            context.util.flux_step_callback(state)

        return step_callback