InvokeAI/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

294 lines
12 KiB
Python
Raw Normal View History

import copy
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
2024-12-14 15:40:25 +00:00
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
from invokeai.backend.tiles.utils import TBLR
from invokeai.backend.util.devices import TorchDevice
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
"""Crop a ControlNetData object to a region."""
# Create a shallow copy of the control_data object.
control_data_copy = copy.copy(control_data)
# The ControlNet reference image is the only attribute that needs to be cropped.
control_data_copy.image_tensor = control_data.image_tensor[
:,
:,
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
]
return control_data_copy
@invocation(
"tiled_multi_diffusion_denoise_latents",
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
classification=Classification.Beta,
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
2024-06-19 11:07:35 -04:00
This node handles automatically tiling the input image, and is primarily intended for global refinement of images
in tiled upscaling workflows. Future Multi-Diffusion nodes should allow the user to specify custom regions with
different parameters for each region to harness the full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
"""
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
tile_height: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
)
tile_width: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
)
tile_overlap: int = InputField(
default=32,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
"amount.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
2024-06-19 11:07:35 -04:00
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
control: ControlField | list[ControlField] | None = InputField(
default=None,
input=Input.Connection,
)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
) -> MultiDiffusionPipeline:
# TODO(ryand): Get rid of this FakeVae hack.
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
2024-06-19 11:07:35 -04:00
vae=FakeVae(),
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# Convert tile image-space dimensions to latent-space dimensions.
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
# - How much overlap 'context' to provide for each denoising step.
# - How much overlap to use during merging/blending.
# - Should we 'jitter' the tile locations in each step so that the seams are in different places?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=latent_tile_height,
tile_width=latent_tile_width,
min_overlap=latent_tile_overlap,
)
2024-06-19 13:54:54 -04:00
# Get the unet's config so that we can pass the base to sd_step_callback().
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
2024-12-14 15:40:25 +00:00
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
2024-12-14 15:40:25 +00:00
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with (
ExitStack() as exit_stack,
unet_info as unet,
LayerPatcher.apply_smart_model_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# Split the controlnet_data into tiles.
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
controlnet_data_tiles: list[list[ControlNetData]] = []
for tile in tiles:
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
controlnet_data_tiles.append(tile_controlnet_data)
# Prepare the MultiDiffusionRegionConditioning list.
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
target_overlap=latent_tile_overlap,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
callback=step_callback,
)
result_latents = result_latents.to("cpu")
2024-06-19 11:07:35 -04:00
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)