2024-09-11 16:33:19 +00:00
from contextlib import ExitStack
2024-12-09 16:01:29 -05:00
from typing import Callable , Iterator , Optional , Tuple , Union
2024-08-29 14:17:08 +00:00
2024-12-16 16:22:54 +00:00
import einops
2024-10-18 18:52:12 +00:00
import numpy as np
import numpy . typing as npt
2024-08-06 21:51:22 +00:00
import torch
2024-08-29 19:05:44 +00:00
import torchvision . transforms as tv_transforms
2024-12-16 16:22:54 +00:00
from PIL import Image
2024-08-29 19:05:44 +00:00
from torchvision . transforms . functional import resize as tv_resize
2024-10-16 18:11:48 +00:00
from transformers import CLIPImageProcessor , CLIPVisionModelWithProjection
2024-08-06 21:51:22 +00:00
2024-08-22 15:29:59 +00:00
from invokeai . app . invocations . baseinvocation import BaseInvocation , Classification , invocation
2024-08-12 18:23:02 +00:00
from invokeai . app . invocations . fields import (
2024-08-29 19:05:44 +00:00
DenoiseMaskField ,
2024-08-12 18:23:02 +00:00
FieldDescriptions ,
2024-08-23 13:50:01 -04:00
FluxConditioningField ,
2024-10-15 22:28:59 +00:00
ImageField ,
2024-08-12 18:23:02 +00:00
Input ,
InputField ,
2024-08-29 14:17:08 +00:00
LatentsField ,
2024-08-12 18:23:02 +00:00
WithBoard ,
WithMetadata ,
)
2024-10-08 21:24:55 +00:00
from invokeai . app . invocations . flux_controlnet import FluxControlNetField
2024-12-16 16:22:54 +00:00
from invokeai . app . invocations . flux_vae_encode import FluxVaeEncodeInvocation
2024-10-15 22:28:59 +00:00
from invokeai . app . invocations . ip_adapter import IPAdapterField
2024-12-12 13:45:07 -05:00
from invokeai . app . invocations . model import ControlLoRAField , LoRAField , TransformerField , VAEField
2024-08-29 14:50:58 +00:00
from invokeai . app . invocations . primitives import LatentsOutput
2024-08-06 21:51:22 +00:00
from invokeai . app . services . shared . invocation_context import InvocationContext
2024-10-07 22:17:06 +00:00
from invokeai . backend . flux . controlnet . instantx_controlnet_flux import InstantXControlNetFlux
2024-10-04 14:28:23 +00:00
from invokeai . backend . flux . controlnet . xlabs_controlnet_flux import XLabsControlNetFlux
2024-08-29 19:05:44 +00:00
from invokeai . backend . flux . denoise import denoise
2024-10-07 19:04:29 +00:00
from invokeai . backend . flux . extensions . inpaint_extension import InpaintExtension
2024-10-07 22:17:06 +00:00
from invokeai . backend . flux . extensions . instantx_controlnet_extension import InstantXControlNetExtension
2024-11-20 19:48:04 +00:00
from invokeai . backend . flux . extensions . regional_prompting_extension import RegionalPromptingExtension
2024-10-07 22:17:06 +00:00
from invokeai . backend . flux . extensions . xlabs_controlnet_extension import XLabsControlNetExtension
2024-10-15 22:28:59 +00:00
from invokeai . backend . flux . extensions . xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai . backend . flux . ip_adapter . xlabs_ip_adapter_flux import XlabsIpAdapterFlux
2024-08-19 10:14:58 -04:00
from invokeai . backend . flux . model import Flux
2024-08-29 19:05:44 +00:00
from invokeai . backend . flux . sampling_utils import (
2024-09-24 23:38:20 +00:00
clip_timestep_schedule_fractional ,
2024-08-29 19:05:44 +00:00
generate_img_ids ,
get_noise ,
get_schedule ,
pack ,
unpack ,
)
2024-11-20 19:48:04 +00:00
from invokeai . backend . flux . text_conditioning import FluxTextConditioning
2024-09-11 16:33:19 +00:00
from invokeai . backend . model_manager . config import ModelFormat
2024-12-17 17:19:12 +00:00
from invokeai . backend . patches . layer_patcher import LayerPatcher
2024-12-13 14:42:29 +00:00
from invokeai . backend . patches . lora_conversions . flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
2024-12-14 15:40:25 +00:00
from invokeai . backend . patches . model_patch_raw import ModelPatchRaw
2024-08-30 12:31:29 -04:00
from invokeai . backend . stable_diffusion . diffusers_pipeline import PipelineIntermediateState
2024-11-20 19:48:04 +00:00
from invokeai . backend . stable_diffusion . diffusion . conditioning_data import FLUXConditioningInfo
2024-08-16 20:22:49 +00:00
from invokeai . backend . util . devices import TorchDevice
2024-08-06 21:51:22 +00:00
2024-08-08 18:23:20 +00:00
2024-08-06 21:51:22 +00:00
@invocation(
    "flux_denoise",
    title="FLUX Denoise",
    tags=["image", "flux"],
    category="image",
    version="3.2.2",
    classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run denoising process with a FLUX transformer model."""

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.denoise_mask,
        input=Input.Connection,
    )
    # Must be 0 when `latents` is not provided (starting from pure noise).
    denoising_start: float = InputField(
        default=0.0,
        ge=0,
        le=1,
        description=FieldDescriptions.denoising_start,
    )
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    # Only consulted for image-to-image (when `latents` is provided).
    add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    # Optional FLUX Control LoRA. Its conditioning image is VAE-encoded with `controlnet_vae`; it cannot be used
    # with a FLUX schnell model.
    control_lora: Optional[ControlLoRAField] = InputField(
        description=FieldDescriptions.control_lora_model, input=Input.Connection, title="Control LoRA", default=None
    )
    positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
        default=None,
        description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
        input=Input.Connection,
    )
    # Either a single scale for every denoising step, or a per-step schedule with one value per denoising step.
    cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    # Steps outside [cfg_scale_start_step, cfg_scale_end_step] (inclusive) use a cfg scale of 1.0.
    cfg_scale_start_step: int = InputField(
        default=0,
        title="CFG Scale Start Step",
        description="Index of the first step to apply cfg_scale. Negative indices count backwards from the "
        + "last step (e.g. a value of -1 refers to the final step).",
    )
    cfg_scale_end_step: int = InputField(
        default=-1,
        title="CFG Scale End Step",
        description="Index of the last step to apply cfg_scale. Negative indices count backwards from the "
        + "last step (e.g. a value of -1 refers to the final step).",
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
    control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
        default=None, input=Input.Connection, description="ControlNet models."
    )
    # VAE used to encode conditioning images for InstantX ControlNets and for the Control LoRA.
    controlnet_vae: VAEField | None = InputField(
        default=None,
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
        description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
    )
2024-08-06 21:51:22 +00:00
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
    """Run the denoising process, then persist the resulting latents.

    The latents are detached and moved to the CPU before being saved through `context.tensors`.

    Returns:
        LatentsOutput: Refers to the saved latents tensor. No seed is attached (seed=None).
    """
    cpu_latents = self._run_diffusion(context).detach().to("cpu")
    tensor_name = context.tensors.save(tensor=cpu_latents)
    return LatentsOutput.build(latents_name=tensor_name, latents=cpu_latents, seed=None)
2024-08-06 21:51:22 +00:00
def _run_diffusion(
    self,
    context: InvocationContext,
) -> torch.Tensor:
    """Run the FLUX denoising loop and return the resulting latents in 'unpacked' format.

    Flow:
      1. Load optional init latents, generate noise, and load positive/negative text conditioning.
      2. Build and clip the timestep schedule, then prepare the starting latent `x`.
      3. Short-circuit (return `x` as-is) when there are no denoising steps to take.
      4. Prepare auxiliary inputs: Control LoRA image cond, inpaint mask, IP-Adapter CLIP embeds, cfg schedule.
      5. Load ControlNets, then the transformer; apply LoRA patches; prepare IP-Adapter extensions; denoise.

    The ordering of steps 4-5 is deliberate: work that may need other models (VAE, CLIP vision) is done before
    the transformer is loaded, to keep peak memory down.

    Raises:
        ValueError: If denoising_start > 0 without input latents, if a denoise_mask is given without input
            latents, if a Control LoRA is combined with a FLUX schnell model, or if the transformer's model
            format is unsupported.
    """
    inference_dtype = torch.bfloat16
    device = TorchDevice.choose_torch_device()

    # Load the input latents, if provided.
    init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
    if init_latents is not None:
        init_latents = init_latents.to(device=device, dtype=inference_dtype)

    # Prepare input noise.
    noise = get_noise(
        num_samples=1,
        height=self.height,
        width=self.width,
        device=device,
        dtype=inference_dtype,
        seed=self.seed,
    )
    b, _c, latent_h, latent_w = noise.shape
    # Size of the latent grid after `pack()`; this is checked against the packed tensor's sequence length below.
    packed_h = latent_h // 2
    packed_w = latent_w // 2

    # Load the conditioning data.
    pos_text_conditionings = self._load_text_conditioning(
        context=context,
        cond_field=self.positive_text_conditioning,
        packed_height=packed_h,
        packed_width=packed_w,
        dtype=inference_dtype,
        device=device,
    )
    neg_text_conditionings: list[FluxTextConditioning] | None = None
    if self.negative_text_conditioning is not None:
        neg_text_conditionings = self._load_text_conditioning(
            context=context,
            cond_field=self.negative_text_conditioning,
            packed_height=packed_h,
            packed_width=packed_w,
            dtype=inference_dtype,
            device=device,
        )
    pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
        pos_text_conditionings, img_seq_len=packed_h * packed_w
    )
    neg_regional_prompting_extension = (
        RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
        if neg_text_conditionings
        else None
    )

    transformer_info = context.models.load(self.transformer.transformer)
    is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")

    # Calculate the timestep schedule.
    timesteps = get_schedule(
        num_steps=self.num_steps,
        image_seq_len=packed_h * packed_w,
        shift=not is_schnell,
    )
    # Clip the timesteps schedule based on denoising_start and denoising_end.
    timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)

    # Prepare input latent image.
    if init_latents is not None:
        # If init_latents is provided, we are doing image-to-image.
        if is_schnell:
            context.logger.warning(
                "Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
                "to be poor. Consider using a FLUX dev model instead."
            )

        if self.add_noise:
            # Noise the orig_latents by the appropriate amount for the first timestep.
            t_0 = timesteps[0]
            x = t_0 * noise + (1.0 - t_0) * init_latents
        else:
            x = init_latents
    else:
        # init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
        if self.denoising_start > 1e-5:
            raise ValueError("denoising_start should be 0 when initial latents are not provided.")
        x = noise

    # If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
    # denoising steps.
    # NOTE(review): this path returns `x` in inference_dtype (bfloat16), whereas the full path below returns
    # float32 via `x.float()`.
    if len(timesteps) <= 1:
        return x

    if is_schnell and self.control_lora:
        raise ValueError("Control LoRAs cannot be used with FLUX Schnell")

    # Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
    img_cond = self._prep_structural_control_img_cond(context)

    inpaint_mask = self._prep_inpaint_mask(context, x)

    img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

    # Pack all latent tensors.
    init_latents = pack(init_latents) if init_latents is not None else None
    inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
    img_cond = pack(img_cond) if img_cond is not None else None
    noise = pack(noise)
    x = pack(x)

    # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
    # packed_w correctly.
    assert packed_h * packed_w == x.shape[1]

    # Prepare inpaint extension.
    inpaint_extension: InpaintExtension | None = None
    if inpaint_mask is not None:
        # A denoise_mask requires input latents to inpaint into. This is a user-input error rather than an
        # internal invariant, so raise explicitly (an `assert` would be stripped under `python -O`).
        if init_latents is None:
            raise ValueError("denoise_mask requires input latents to be provided.")
        inpaint_extension = InpaintExtension(
            init_latents=init_latents,
            inpaint_mask=inpaint_mask,
            noise=noise,
        )

    # Compute the IP-Adapter image prompt clip embeddings.
    # We do this before loading other models to minimize peak memory.
    # TODO(ryand): We should really do this in a separate invocation to benefit from caching.
    ip_adapter_fields = self._normalize_ip_adapter_fields()
    pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
        ip_adapter_fields, context
    )

    cfg_scale = self.prep_cfg_scale(
        cfg_scale=self.cfg_scale,
        timesteps=timesteps,
        cfg_scale_start_step=self.cfg_scale_start_step,
        cfg_scale_end_step=self.cfg_scale_end_step,
    )

    with ExitStack() as exit_stack:
        # Prepare ControlNet extensions.
        # Note: We do this before loading the transformer model to minimize peak memory (see implementation).
        controlnet_extensions = self._prep_controlnet_extensions(
            context=context,
            exit_stack=exit_stack,
            latent_height=latent_h,
            latent_width=latent_w,
            dtype=inference_dtype,
            device=x.device,
        )

        # Load the transformer model.
        (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
        assert isinstance(transformer, Flux)
        config = transformer_info.config
        assert config is not None

        # Determine if the model is quantized.
        # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
        # slower inference than direct patching, but is agnostic to the quantization format.
        if config.format in [ModelFormat.Checkpoint]:
            model_is_quantized = False
        elif config.format in [
            ModelFormat.BnbQuantizedLlmInt8b,
            ModelFormat.BnbQuantizednf4b,
            ModelFormat.GGUFQuantized,
        ]:
            model_is_quantized = True
        else:
            raise ValueError(f"Unsupported model format: {config.format}")

        # Apply LoRA models to the transformer.
        # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
        exit_stack.enter_context(
            LayerPatcher.apply_smart_model_patches(
                model=transformer,
                patches=self._lora_iterator(context),
                prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                dtype=inference_dtype,
                cached_weights=cached_weights,
                force_sidecar_patching=model_is_quantized,
            )
        )

        # Prepare IP-Adapter extensions.
        pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
            pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
            neg_image_prompt_clip_embeds=neg_image_prompt_clip_embeds,
            ip_adapter_fields=ip_adapter_fields,
            context=context,
            exit_stack=exit_stack,
            dtype=inference_dtype,
        )

        x = denoise(
            model=transformer,
            img=x,
            img_ids=img_ids,
            pos_regional_prompting_extension=pos_regional_prompting_extension,
            neg_regional_prompting_extension=neg_regional_prompting_extension,
            timesteps=timesteps,
            step_callback=self._build_step_callback(context),
            guidance=self.guidance,
            cfg_scale=cfg_scale,
            inpaint_extension=inpaint_extension,
            controlnet_extensions=controlnet_extensions,
            pos_ip_adapter_extensions=pos_ip_adapter_extensions,
            neg_ip_adapter_extensions=neg_ip_adapter_extensions,
            img_cond=img_cond,
        )

    x = unpack(x.float(), self.height, self.width)
    return x
2024-08-29 19:05:44 +00:00
2024-11-20 19:48:04 +00:00
def _load_text_conditioning(
    self,
    context: InvocationContext,
    cond_field: FluxConditioningField | list[FluxConditioningField],
    packed_height: int,
    packed_width: int,
    dtype: torch.dtype,
    device: torch.device,
) -> list[FluxTextConditioning]:
    """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields.

    Args:
        context: Used to load the stored conditioning data and any regional mask tensors.
        cond_field: A single conditioning field, or a list of them.
        packed_height: Packed latent height, forwarded to regional-mask preprocessing.
        packed_width: Packed latent width, forwarded to regional-mask preprocessing.
        dtype: Target dtype for the embeddings and the preprocessed masks.
        device: Target device for the embeddings and masks.

    Returns:
        One FluxTextConditioning per input field, in input order.
    """
    fields = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field

    results: list[FluxTextConditioning] = []
    for field in fields:
        stored = context.conditioning.load(field.conditioning_name)
        # Each stored FLUX conditioning must contain exactly one FLUXConditioningInfo entry.
        assert len(stored.conditionings) == 1
        info = stored.conditionings[0]
        assert isinstance(info, FLUXConditioningInfo)
        info = info.to(dtype=dtype, device=device)
        t5_embeds, clip_embeds = info.t5_embeds, info.clip_embeds

        # An optional mask restricts this prompt to a region of the image.
        regional_mask: Optional[torch.Tensor] = None
        if field.mask is not None:
            raw_mask = context.tensors.load(field.mask.tensor_name).to(device=device)
            regional_mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
                raw_mask, packed_height, packed_width, dtype, device
            )

        results.append(FluxTextConditioning(t5_embeds, clip_embeds, regional_mask))
    return results
2024-10-21 14:52:02 +00:00
@classmethod
def prep_cfg_scale(
    cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
) -> list[float]:
    """Prepare the per-step cfg_scale schedule.

    - If cfg_scale is a scalar (int or float), it is repeated for every denoising step (a constant schedule).
    - If cfg_scale is a list, it is treated as an explicit per-step schedule and must contain exactly one value
      per denoising step.
    - Steps outside [cfg_scale_start_step, cfg_scale_end_step] (inclusive) are set to 1.0, which disables cfg for
      those steps. Negative step indices count backwards from the last step (-1 is the final step).

    Args:
        cfg_scale: A scalar cfg scale or a per-step list of cfg scales.
        timesteps: The timestep schedule. There are len(timesteps) - 1 denoising steps.
        cfg_scale_start_step: Index of the first step to apply cfg_scale (may be negative).
        cfg_scale_end_step: Index of the last step to apply cfg_scale (may be negative).

    Returns:
        A new list with one cfg scale per denoising step.

    Raises:
        ValueError: If cfg_scale has an unsupported type, a list schedule has the wrong length, a step index is out
            of range, or the start step is after the end step.
    """
    # num_steps is the number of denoising steps, which is one less than the number of timesteps.
    num_steps = len(timesteps) - 1

    # Normalize cfg_scale to a list. Accept ints as well as floats; bool is an int subclass but is not a
    # meaningful cfg scale, so it is still rejected.
    cfg_scale_list: list[float]
    if isinstance(cfg_scale, (int, float)) and not isinstance(cfg_scale, bool):
        cfg_scale_list = [float(cfg_scale)] * num_steps
    elif isinstance(cfg_scale, list):
        cfg_scale_list = cfg_scale
    else:
        raise ValueError(f"Unsupported cfg_scale type: {type(cfg_scale)}")

    # A user-supplied schedule of the wrong length is an input error. Raise explicitly rather than `assert`, which
    # is stripped under `python -O`.
    if len(cfg_scale_list) != num_steps:
        raise ValueError(
            f"cfg_scale schedule has {len(cfg_scale_list)} entries, but there are {num_steps} denoising steps."
        )

    # Handle negative indices for cfg_scale_start_step and cfg_scale_end_step.
    start_step_index = cfg_scale_start_step
    if start_step_index < 0:
        start_step_index = num_steps + start_step_index
    end_step_index = cfg_scale_end_step
    if end_step_index < 0:
        end_step_index = num_steps + end_step_index

    # Validate the start and end step indices.
    if not (0 <= start_step_index < num_steps):
        raise ValueError(f"Invalid cfg_scale_start_step. Out of range: {cfg_scale_start_step}.")
    if not (0 <= end_step_index < num_steps):
        raise ValueError(f"Invalid cfg_scale_end_step. Out of range: {cfg_scale_end_step}.")
    if start_step_index > end_step_index:
        raise ValueError(
            f"cfg_scale_start_step ({cfg_scale_start_step}) must be before cfg_scale_end_step "
            + f"({cfg_scale_end_step})."
        )

    # Set values outside the start and end step indices to 1.0. This is equivalent to disabling cfg_scale for those
    # steps.
    clipped_cfg_scale = [1.0] * num_steps
    clipped_cfg_scale[start_step_index : end_step_index + 1] = cfg_scale_list[start_step_index : end_step_index + 1]
    return clipped_cfg_scale
2024-08-29 19:05:44 +00:00
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
    """Prepare the inpaint mask.

    - Loads the mask
    - Inverts it (the returned convention is the opposite of the input denoise_mask convention)
    - Resizes it to the spatial size of `latents`
    - Casts to same device / dtype as latents
    - Expands mask to the same shape as latents so that they line up after 'packing'

    Args:
        context (InvocationContext): The invocation context, for loading the inpaint mask.
        latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
            device, and dtype for the inpaint mask.

    Returns:
        torch.Tensor | None: The inpaint mask, or None if no denoise_mask was provided. This is the INVERSE of the
            input denoise_mask convention: values of 1.0 represent the regions to be fully denoised, and 0.0
            represent the regions to be preserved.
    """
    if self.denoise_mask is None:
        return None
    mask = context.tensors.load(self.denoise_mask.mask_name)

    # The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
    # 1.0 represents the regions to be preserved.
    # We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
    mask = 1.0 - mask

    _, _, latent_height, latent_width = latents.shape
    mask = tv_resize(
        img=mask,
        size=[latent_height, latent_width],
        interpolation=tv_transforms.InterpolationMode.BILINEAR,
        antialias=False,
    )

    mask = mask.to(device=latents.device, dtype=latents.dtype)

    # Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
    # `latents`.
    return mask.expand_as(latents)
2024-10-03 16:18:57 +00:00
def _prep_controlnet_extensions(
    self,
    context: InvocationContext,
    exit_stack: ExitStack,
    latent_height: int,
    latent_width: int,
    dtype: torch.dtype,
    device: torch.device,
) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
    """Build one ControlNet extension per entry in `self.control`.

    All conditioning tensors are computed first (InstantX ControlNets may need to run the VAE). Only then are the
    ControlNet models entered as context managers on `exit_stack`, to keep peak memory down.

    Args:
        context: The invocation context, for loading models and control images.
        exit_stack: The loaded ControlNet models are entered as context managers on this stack.
        latent_height: Height of the unpacked latent tensor.
        latent_width: Width of the unpacked latent tensor.
        dtype: dtype for the ControlNet conditioning tensors.
        device: Device for the ControlNet conditioning tensors.

    Returns:
        The ControlNet extensions, in the same order as `self.control`.

    Raises:
        ValueError: If `self.control` has an unsupported type, if an InstantX ControlNet is used without
            `controlnet_vae`, or if a ControlNet model has an unsupported type.
    """
    # Normalize the controlnet input to list[FluxControlNetField].
    controlnets: list[FluxControlNetField]
    if self.control is None:
        controlnets = []
    elif isinstance(self.control, FluxControlNetField):
        controlnets = [self.control]
    elif isinstance(self.control, list):
        controlnets = self.control
    else:
        raise ValueError(f"Unsupported controlnet type: {type(self.control)}")

    # TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
    # before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
    # minimize peak memory.

    # First, load the ControlNet models so that we can determine the ControlNet types.
    controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]

    # Calculate the controlnet conditioning tensors.
    # We do this before entering the ControlNet model contexts (below) because it may require running the VAE, and
    # we are trying to keep peak memory down.
    controlnet_conds: list[torch.Tensor] = []
    for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
        image = context.images.get_pil(controlnet.image.image_name)
        if isinstance(controlnet_model.model, InstantXControlNetFlux):
            if self.controlnet_vae is None:
                raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
            vae_info = context.models.load(self.controlnet_vae.vae)
            controlnet_conds.append(
                InstantXControlNetExtension.prepare_controlnet_cond(
                    controlnet_image=image,
                    vae_info=vae_info,
                    latent_height=latent_height,
                    latent_width=latent_width,
                    dtype=dtype,
                    device=device,
                    resize_mode=controlnet.resize_mode,
                )
            )
        elif isinstance(controlnet_model.model, XLabsControlNetFlux):
            controlnet_conds.append(
                XLabsControlNetExtension.prepare_controlnet_cond(
                    controlnet_image=image,
                    latent_height=latent_height,
                    latent_width=latent_width,
                    dtype=dtype,
                    device=device,
                    resize_mode=controlnet.resize_mode,
                )
            )
        else:
            # Fail here with a clear message. Otherwise no cond would be appended, and the strict zip() below would
            # raise an opaque length-mismatch error instead of reporting the unsupported model type.
            raise ValueError(f"Unsupported ControlNet model type: {type(controlnet_model.model)}")

    # Finally, load the ControlNet models and initialize the ControlNet extensions.
    controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
    for controlnet, controlnet_cond, controlnet_model in zip(
        controlnets, controlnet_conds, controlnet_models, strict=True
    ):
        model = exit_stack.enter_context(controlnet_model)

        if isinstance(model, XLabsControlNetFlux):
            controlnet_extensions.append(
                XLabsControlNetExtension(
                    model=model,
                    controlnet_cond=controlnet_cond,
                    weight=controlnet.control_weight,
                    begin_step_percent=controlnet.begin_step_percent,
                    end_step_percent=controlnet.end_step_percent,
                )
            )
        elif isinstance(model, InstantXControlNetFlux):
            # A negative (or missing) control mode means "no control mode".
            instantx_control_mode: torch.Tensor | None = None
            if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
                instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
                instantx_control_mode = instantx_control_mode.reshape([-1, 1])

            controlnet_extensions.append(
                InstantXControlNetExtension(
                    model=model,
                    controlnet_cond=controlnet_cond,
                    instantx_control_mode=instantx_control_mode,
                    weight=controlnet.control_weight,
                    begin_step_percent=controlnet.begin_step_percent,
                    end_step_percent=controlnet.end_step_percent,
                )
            )
        else:
            raise ValueError(f"Unsupported ControlNet model type: {type(model)}")

    return controlnet_extensions
2024-10-03 16:18:57 +00:00
2024-12-16 16:22:54 +00:00
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
    """VAE-encode the Control LoRA's conditioning image, if a Control LoRA is configured.

    The image is converted to RGB and resized (bicubic) to (self.width, self.height). Its uint8 pixel values are
    then mapped to [-1, 1], it is rearranged to a single channels-first batch, and it is encoded with
    `controlnet_vae`.

    Returns:
        The VAE-encoded conditioning latents, or None when `control_lora` is not set.

    Raises:
        ValueError: If `control_lora` is set but `controlnet_vae` is not.
    """
    lora_field = self.control_lora
    if lora_field is None:
        return None
    if not self.controlnet_vae:
        raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")

    # Load the conditioning image and resize it to the target image size.
    pil_image = (
        context.images.get_pil(lora_field.img.image_name)
        .convert("RGB")
        .resize((self.width, self.height), Image.Resampling.BICUBIC)
    )
    pixels = np.array(pil_image)

    # Map pixel values from [0, 255] to [-1, 1], following the original implementations:
    # https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L34
    # https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L60
    normalized = torch.from_numpy(pixels).float() / 127.5 - 1.0
    batched = einops.rearrange(normalized, "h w c -> 1 c h w")

    vae_info = context.models.load(self.controlnet_vae.vae)
    return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=batched)
2024-10-15 22:28:59 +00:00
def _normalize_ip_adapter_fields ( self ) - > list [ IPAdapterField ] :
if self . ip_adapter is None :
return [ ]
elif isinstance ( self . ip_adapter , IPAdapterField ) :
return [ self . ip_adapter ]
elif isinstance ( self . ip_adapter , list ) :
return self . ip_adapter
else :
raise ValueError ( f " Unsupported IP-Adapter type: { type ( self . ip_adapter ) } " )
def _prep_ip_adapter_image_prompt_clip_embeds(
    self,
    ip_adapter_fields: list[IPAdapterField],
    context: InvocationContext,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Run each IP-Adapter's CLIP vision encoder on its image prompt.

    For every IP-Adapter field, the single prompt image is encoded to produce the positive CLIP embedding. An
    all-zeros (black) image of the same shape is encoded to produce the negative CLIP embedding.

    Args:
        ip_adapter_fields: The IP-Adapter fields to process.
        context: Invocation context used to load the prompt images and the CLIP image encoder.

    Returns:
        A `(positive_embeds, negative_embeds)` pair of lists, each with one `image_embeds` tensor per field, in the
        same order as `ip_adapter_fields`.

    Raises:
        ValueError: If a field's `image` is neither an ImageField nor a list, or does not hold exactly one image.
    """
    processor = CLIPImageProcessor()

    def encode(encoder: CLIPVisionModelWithProjection, images: list[npt.NDArray[np.uint8]]) -> torch.Tensor:
        # Preprocess with the (default-configured) CLIP processor, then run the vision encoder on its device/dtype.
        pixel_values: torch.Tensor = processor(images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=encoder.device, dtype=encoder.dtype)
        return encoder(pixel_values).image_embeds

    pos_embeds: list[torch.Tensor] = []
    neg_embeds: list[torch.Tensor] = []

    for field in ip_adapter_fields:
        # `field.image` may be a single ImageField or a list of them; normalize to a list.
        raw_image = field.image
        if isinstance(raw_image, ImageField):
            image_fields: list[ImageField] = [raw_image]
        elif isinstance(raw_image, list):
            image_fields = raw_image
        else:
            raise ValueError(f"Unsupported IP-Adapter image type: {type(raw_image)}")

        if len(image_fields) != 1:
            raise ValueError(
                f"FLUX IP-Adapter only supports a single image prompt (received {len(image_fields)})."
            )

        pil_images = [context.images.get_pil(image_field.image_name, mode="RGB") for image_field in image_fields]

        pos_images: list[npt.NDArray[np.uint8]] = []
        neg_images: list[npt.NDArray[np.uint8]] = []
        for pil_image in pil_images:
            assert pil_image.mode == "RGB"
            pixels = np.array(pil_image)
            pos_images.append(pixels)
            # A black image serves as the negative image prompt, for parity with
            # https://github.com/XLabs-AI/x-flux-comfyui/blob/45c834727dd2141aebc505ae4b01f193a8414e38/nodes.py#L592-L593
            # An alternative scheme would be to apply zeros_like() after calling the clip_image_processor.
            neg_images.append(np.zeros_like(pixels))

        with context.models.load(field.image_encoder_model) as encoder:
            assert isinstance(encoder, CLIPVisionModelWithProjection)
            pos_embed = encode(encoder, pos_images)
            neg_embed = encode(encoder, neg_images)

        pos_embeds.append(pos_embed)
        neg_embeds.append(neg_embed)

    return pos_embeds, neg_embeds
2024-10-15 22:28:59 +00:00
def _prep_ip_adapter_extensions(
    self,
    ip_adapter_fields: list[IPAdapterField],
    pos_image_prompt_clip_embeds: list[torch.Tensor],
    neg_image_prompt_clip_embeds: list[torch.Tensor],
    context: InvocationContext,
    exit_stack: ExitStack,
    dtype: torch.dtype,
) -> tuple[list[XLabsIPAdapterExtension], list[XLabsIPAdapterExtension]]:
    """Build positive and negative XLabs IP-Adapter extensions, one of each per IP-Adapter field.

    For each field, the IP-Adapter model is loaded via `context.models.load` and entered on `exit_stack`, so it stays
    loaded until the caller closes that stack. The model is then cast to `dtype` and wrapped in two
    `XLabsIPAdapterExtension`s:
    - one using the field's positive CLIP embedding;
    - one using its negative CLIP embedding.
    Both share the field's weight and begin/end step percentages. `run_image_proj(dtype=dtype)` is called on each
    extension before it is returned.

    Args:
        ip_adapter_fields: IP-Adapter fields, in order.
        pos_image_prompt_clip_embeds: Positive CLIP image embeddings, one per field.
        neg_image_prompt_clip_embeds: Negative CLIP image embeddings, one per field.
        context: Invocation context used to load the IP-Adapter models.
        exit_stack: Exit stack that owns the loaded-model context managers.
        dtype: Target dtype for the IP-Adapter models and image projections.

    Returns:
        `(positive_extensions, negative_extensions)`, each ordered like `ip_adapter_fields`.

    Raises:
        ValueError: If any field has a mask (not yet supported for FLUX). All fields are checked before any
            IP-Adapter model is loaded. A ValueError is also raised by `zip(..., strict=True)` if the three input
            lists differ in length.
    """
    # Reject unsupported requests before loading any (potentially large) IP-Adapter weights.
    if any(field.mask is not None for field in ip_adapter_fields):
        raise ValueError("IP-Adapter masks are not yet supported in Flux.")

    pos_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
    neg_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
    for field, pos_embed, neg_embed in zip(
        ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
    ):
        ip_adapter_model = exit_stack.enter_context(context.models.load(field.ip_adapter_model))
        assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
        # NOTE(review): `nn.Module.to` casts in place, so this also changes the dtype of the loaded model object.
        ip_adapter_model = ip_adapter_model.to(dtype=dtype)

        # Positive and negative extensions differ only in the CLIP embedding they carry.
        for clip_embed, extensions in (
            (pos_embed, pos_ip_adapter_extensions),
            (neg_embed, neg_ip_adapter_extensions),
        ):
            extension = XLabsIPAdapterExtension(
                model=ip_adapter_model,
                image_prompt_clip_embed=clip_embed,
                weight=field.weight,
                begin_step_percent=field.begin_step_percent,
                end_step_percent=field.end_step_percent,
            )
            extension.run_image_proj(dtype=dtype)
            extensions.append(extension)

    return pos_ip_adapter_extensions, neg_ip_adapter_extensions
2024-10-15 22:28:59 +00:00
2024-12-14 15:40:25 +00:00
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
    """Lazily load each LoRA to apply to the transformer and yield it with its weight.

    Yields the transformer's LoRAs in their stored order, followed by the Control LoRA if one is set. Each model is
    loaded only when the generator advances to it.

    Yields:
        `(model_patch, weight)` pairs, where `model_patch` is the loaded ModelPatchRaw.
    """
    lora_fields: list[Union[LoRAField, ControlLoRAField]] = list(self.transformer.loras)
    if self.control_lora:
        # FLUX structural control LoRAs modify the shape of some weights, so they must be applied last.
        lora_fields.append(self.control_lora)

    for lora_field in lora_fields:
        loaded = context.models.load(lora_field.lora)
        assert isinstance(loaded.model, ModelPatchRaw)
        yield (loaded.model, lora_field.weight)
        # Release this generator's reference to the loaded model before moving on to the next one.
        del loaded
2024-09-02 23:21:33 -04:00
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
    """Return a per-step callback that post-processes the latents and forwards the state.

    On each call, the callback performs three steps:
    1. Casts `state.latents` to float32.
    2. Converts them via `unpack(..., self.height, self.width)`, squeezes them, and writes the result back to
       `state.latents`.
    3. Passes the mutated state to `context.util.flux_step_callback`.

    `self.height` and `self.width` are read each time the callback runs, not when it is built.
    """

    def on_step(state: PipelineIntermediateState) -> None:
        unpacked = unpack(state.latents.float(), self.height, self.width)
        state.latents = unpacked.squeeze()
        context.util.flux_step_callback(state)

    return on_step