InvokeAI/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

from __future__ import annotations

import math
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias

from invokeai.app.services.config import InvokeAIAppConfig
from .cross_attention_control import (
    Arguments,
    Context,
    CrossAttentionType,
    SwapCrossAttnContext,
    get_cross_attention_modules,
    setup_cross_attention_control_attention_processors,
)
from .cross_attention_map_saving import AttentionMapSaver

ModelForwardCallback: TypeAlias = Union[
    # x, t, conditioning, Optional[cross-attention kwargs]
    Callable[
        [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
        torch.Tensor,
    ],
    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
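
# A minimal sketch of a callback conforming to ModelForwardCallback, assuming a
# diffusers-style UNet. `unet` is a stand-in for whatever model the component
# wraps -- an assumption for illustration, not part of this module:
#
#   def my_forward(x: torch.Tensor, t: torch.Tensor, conditioning: torch.Tensor, **kwargs) -> torch.Tensor:
#       return unet(x, t, encoder_hidden_states=conditioning, **kwargs).sample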


@dataclass
class BasicConditioningInfo:
    embeds: torch.Tensor
    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
    # weight: float
    # mode: ConditioningAlgo

    def to(self, device, dtype=None):
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)
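
# Example construction (an illustrative sketch; the shapes below are typical for
# SDXL but are assumptions, not contracts enforced by this module):
#
#   info = SDXLConditioningInfo(
#       embeds=torch.zeros(1, 77, 2048),
#       extra_conditioning=None,
#       pooled_embeds=torch.zeros(1, 1280),
#       add_time_ids=torch.zeros(1, 6),
#   )
#   info = info.to("cuda", dtype=torch.float16)  # moves all three tensors at once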


@dataclass(frozen=True)
class PostprocessingSettings:
    threshold: float
    warmup: float
    h_symmetry_time_pct: Optional[float]
    v_symmetry_time_pct: Optional[float]


class InvokeAIDiffuserComponent:
    """
    The aim of this component is to provide a single place for code that can be applied identically to
    all InvokeAI diffusion procedures.

    At the moment it includes the following features:
    * Cross attention control ("prompt2prompt")
    * Hybrid conditioning (used for inpainting)
    """

    debug_thresholding = False
    sequential_guidance = False

    @dataclass
    class ExtraConditioningInfo:
        tokens_count_including_eos_bos: int
        cross_attention_control_args: Optional[Arguments] = None

        @property
        def wants_cross_attention_control(self):
            return self.cross_attention_control_args is not None

    def __init__(
        self,
        model,
        model_forward_callback: ModelForwardCallback,
    ):
        """
        :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a callable with arguments (x, sigma, conditioning_to_apply). It will be
            called repeatedly; most likely it should simply call model.forward(x, sigma, conditioning).
        """
        config = InvokeAIAppConfig.get_config()

        self.conditioning = None
        self.model = model
        self.model_forward_callback = model_forward_callback
        self.cross_attention_control_context = None
        self.sequential_guidance = config.sequential_guidance
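
    # A hypothetical construction sketch (`my_unet` and `my_forward` are assumptions
    # for illustration; any callable matching ModelForwardCallback works):
    #
    #   component = InvokeAIDiffuserComponent(my_unet, my_forward)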

    @contextmanager
    def custom_attention_context(
        self,
        unet: UNet2DConditionModel,  # note: may also futz with the text encoder, depending on requested LoRAs
        extra_conditioning_info: Optional[ExtraConditioningInfo],
        step_count: int,
    ):
        old_attn_processors = None
        if extra_conditioning_info and extra_conditioning_info.wants_cross_attention_control:
            # Remember the original attention processors so they can be restored on exit,
            # then install the cross-attention-control processors.
            old_attn_processors = unet.attn_processors
            self.cross_attention_control_context = Context(
                arguments=extra_conditioning_info.cross_attention_control_args,
                step_count=step_count,
            )
            setup_cross_attention_control_attention_processors(
                unet,
                self.cross_attention_control_context,
            )

        try:
            yield None
        finally:
            self.cross_attention_control_context = None
            if old_attn_processors is not None:
                unet.set_attn_processor(old_attn_processors)

        # TODO resuscitate attention map saving
        # self.remove_attention_map_saving()
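
    # Typical usage (a sketch; `component`, `unet`, and `extra_info` are assumptions
    # for illustration):
    #
    #   with component.custom_attention_context(unet, extra_info, step_count=30):
    #       ...  # run the denoising loop; the original processors are restored on exit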

    def setup_attention_map_saving(self, saver: AttentionMapSaver):
        def callback(slice, dim, offset, slice_size, key):
            if dim is not None:
                # sliced tokens attention map saving is not implemented
                return
            saver.add_attention_maps(slice, key)

        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
        for identifier, module in tokens_cross_attention_modules:
            key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
            module.set_attention_slice_calculated_callback(
                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
            )

    def remove_attention_map_saving(self):
        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
        for _, module in tokens_cross_attention_modules:
            module.set_attention_slice_calculated_callback(None)

    def do_controlnet_step(
        self,
        control_data,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        step_index: int,
        total_step_count: int,
        conditioning_data,
    ):
        down_block_res_samples, mid_block_res_sample = None, None

        # control_data should be of type List[ControlNetData]. This loop covers both
        # ControlNet (one ControlNetData in the list) and MultiControlNet (several).
        for i, control_datum in enumerate(control_data):
            control_mode = control_datum.control_mode
            # soft_injection and cfg_injection are the two ControlNet control_mode booleans
            # that are combined at a higher level to make the control_mode enum.
            # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
            # or default weighting (if False).
            soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
            # cfg_injection determines whether to apply ControlNet to only the conditional batch (if True)
            # or to both the conditional and unconditional batches, the default (if False).
            cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"

            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
            # only apply controlnet if the current step is within the controlnet's begin/end step range
            if step_index >= first_control_step and step_index <= last_control_step:
                if cfg_injection:
                    sample_model_input = sample
                else:
                    # Expand the latents input to the control model if doing classifier-free guidance
                    # (which, for now, is always the case: execution stops elsewhere when the
                    # guidance scale is <= 1.0).
                    sample_model_input = torch.cat([sample] * 2)

                added_cond_kwargs = None

                if cfg_injection:  # only applying ControlNet to the conditional batch, not the unconditional
                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
                        added_cond_kwargs = {
                            "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                            "time_ids": conditioning_data.text_embeddings.add_time_ids,
                        }
                    encoder_hidden_states = conditioning_data.text_embeddings.embeds
                    encoder_attention_mask = None
                else:
                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
                        added_cond_kwargs = {
                            "text_embeds": torch.cat(
                                [
                                    # TODO: how to pad? just by zeros? or even truncate?
                                    conditioning_data.unconditioned_embeddings.pooled_embeds,
                                    conditioning_data.text_embeddings.pooled_embeds,
                                ],
                                dim=0,
                            ),
                            "time_ids": torch.cat(
                                [
                                    conditioning_data.unconditioned_embeddings.add_time_ids,
                                    conditioning_data.text_embeddings.add_time_ids,
                                ],
                                dim=0,
                            ),
                        }
                    (
                        encoder_hidden_states,
                        encoder_attention_mask,
                    ) = self._concat_conditionings_for_batch(
                        conditioning_data.unconditioned_embeddings.embeds,
                        conditioning_data.text_embeddings.embeds,
                    )

                if isinstance(control_datum.weight, list):
                    # if controlnet has multiple weights, use the weight for the current step
                    controlnet_weight = control_datum.weight[step_index]
                else:
                    # if controlnet has a single weight, use it for all steps
                    controlnet_weight = control_datum.weight

                # controlnet(s) inference
                down_samples, mid_sample = control_datum.model(
                    sample=sample_model_input,
                    timestep=timestep,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=control_datum.image_tensor,
                    conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
                    encoder_attention_mask=encoder_attention_mask,
                    added_cond_kwargs=added_cond_kwargs,
                    guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
                    return_dict=False,
                )
                if cfg_injection:
                    # ControlNet was inferred only for the conditional batch.
                    # To apply its output to both the unconditional and conditional batches,
                    # prepend zeros for the unconditional batch.
                    down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                    mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

                if down_block_res_samples is None and mid_block_res_sample is None:
                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
                else:
                    # add controlnet outputs together if there are multiple controlnets
                    down_block_res_samples = [
                        samples_prev + samples_curr
                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                    ]
                    mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample
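
    # How a caller typically wires the result into the UNet step (a sketch; the
    # variable names are assumptions for illustration):
    #
    #   down_res, mid_res = component.do_controlnet_step(control_data, sample, t, i, n, cond_data)
    #   uncond, cond = component.do_unet_step(
    #       sample, t, cond_data, i, n,
    #       down_block_additional_residuals=down_res,
    #       mid_block_additional_residual=mid_res,
    #   )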

    def do_unet_step(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        conditioning_data,  # TODO: type
        step_index: int,
        total_step_count: int,
        **kwargs,
    ):
        cross_attention_control_types_to_do = []
        context: Context = self.cross_attention_control_context
        if self.cross_attention_control_context is not None:
            percent_through = step_index / total_step_count
            cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
                percent_through
            )
        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0

        if wants_cross_attention_control:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_cross_attention_controlled_conditioning(
                sample,
                timestep,
                conditioning_data,
                cross_attention_control_types_to_do,
                **kwargs,
            )
        elif self.sequential_guidance:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning_sequentially(
                sample,
                timestep,
                conditioning_data,
                **kwargs,
            )
        else:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning(
                sample,
                timestep,
                conditioning_data,
                **kwargs,
            )

        return unconditioned_next_x, conditioned_next_x
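
    # The caller then combines the two halves with classifier-free guidance,
    # e.g. (a sketch; the guidance scale value is an arbitrary assumption):
    #
    #   noise_pred = component._combine(uncond, cond, guidance_scale=7.5)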

    def do_latent_postprocessing(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        sigma,
        step_index,
        total_step_count,
    ) -> torch.Tensor:
        if postprocessing_settings is not None:
            percent_through = step_index / total_step_count
            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
        return latents

    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
        def _pad_conditioning(cond, target_len, encoder_attention_mask):
            conditioning_attention_mask = torch.ones(
                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
            )

            if cond.shape[1] < target_len:
                conditioning_attention_mask = torch.cat(
                    [
                        conditioning_attention_mask,
                        torch.zeros(
                            (cond.shape[0], target_len - cond.shape[1]), device=cond.device, dtype=cond.dtype
                        ),
                    ],
                    dim=1,
                )

                cond = torch.cat(
                    [
                        cond,
                        torch.zeros(
                            (cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
                            device=cond.device,
                            dtype=cond.dtype,
                        ),
                    ],
                    dim=1,
                )

            if encoder_attention_mask is None:
                encoder_attention_mask = conditioning_attention_mask
            else:
                encoder_attention_mask = torch.cat(
                    [
                        encoder_attention_mask,
                        conditioning_attention_mask,
                    ]
                )

            return cond, encoder_attention_mask

        encoder_attention_mask = None
        if unconditioning.shape[1] != conditioning.shape[1]:
            max_len = max(unconditioning.shape[1], conditioning.shape[1])
            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)

        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
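
    # Sketch of the padding behavior (shapes are illustrative assumptions, e.g. a
    # 77-token negative prompt and a 154-token positive prompt):
    #
    #   uncond = torch.randn(1, 77, 768)
    #   cond = torch.randn(1, 154, 768)
    #   both, mask = component._concat_conditionings_for_batch(uncond, cond)
    #   # both.shape == (2, 154, 768) -- uncond is zero-padded up to 154 tokens
    #   # mask.shape == (2, 154)      -- 0.0 marks the padded positions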

    # methods below are called from do_unet_step and should be considered private to this class.

    def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
        # fast batched path
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

        added_cond_kwargs = None
        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
            added_cond_kwargs = {
                "text_embeds": torch.cat(
                    [
                        # TODO: how to pad? just by zeros? or even truncate?
                        conditioning_data.unconditioned_embeddings.pooled_embeds,
                        conditioning_data.text_embeddings.pooled_embeds,
                    ],
                    dim=0,
                ),
                "time_ids": torch.cat(
                    [
                        conditioning_data.unconditioned_embeddings.add_time_ids,
                        conditioning_data.text_embeddings.add_time_ids,
                    ],
                    dim=0,
                ),
            }

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
            conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
        )
        both_results = self.model_forward_callback(
            x_twice,
            sigma_twice,
            both_conditionings,
            encoder_attention_mask=encoder_attention_mask,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
        return unconditioned_next_x, conditioned_next_x
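
    # The batched path relies on the usual cat/chunk CFG pattern (a sketch with
    # illustrative shapes; `both_results` stands for the UNet output above):
    #
    #   x = torch.randn(1, 4, 64, 64)
    #   x_twice = torch.cat([x] * 2)                  # (2, 4, 64, 64): one forward pass serves both halves
    #   uncond_out, cond_out = both_results.chunk(2)  # split back into two (1, 4, 64, 64) tensors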

    def _apply_standard_conditioning_sequentially(
        self,
        x: torch.Tensor,
        sigma,
        conditioning_data,
        **kwargs,
    ):
        # low-memory sequential path
        uncond_down_block, cond_down_block = None, None
        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
        if down_block_additional_residuals is not None:
            uncond_down_block, cond_down_block = [], []
            for down_block in down_block_additional_residuals:
                _uncond_down, _cond_down = down_block.chunk(2)
                uncond_down_block.append(_uncond_down)
                cond_down_block.append(_cond_down)

        uncond_mid_block, cond_mid_block = None, None
        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
        if mid_block_additional_residual is not None:
            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

        added_cond_kwargs = None
        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
            }

        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.unconditioned_embeddings.embeds,
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )

        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                "time_ids": conditioning_data.text_embeddings.add_time_ids,
            }

        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.text_embeddings.embeds,
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x
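
    # Sketch of the residual split above (the shape is an illustrative assumption):
    #
    #   residual = torch.randn(2, 320, 64, 64)    # batched uncond+cond ControlNet output
    #   uncond_res, cond_res = residual.chunk(2)  # each half is (1, 320, 64, 64)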

    def _apply_cross_attention_controlled_conditioning(
        self,
        x: torch.Tensor,
        sigma,
        conditioning_data,
        cross_attention_control_types_to_do,
        **kwargs,
    ):
        context: Context = self.cross_attention_control_context

        uncond_down_block, cond_down_block = None, None
        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
        if down_block_additional_residuals is not None:
            uncond_down_block, cond_down_block = [], []
            for down_block in down_block_additional_residuals:
                _uncond_down, _cond_down = down_block.chunk(2)
                uncond_down_block.append(_uncond_down)
                cond_down_block.append(_cond_down)

        uncond_mid_block, cond_mid_block = None, None
        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
        if mid_block_additional_residual is not None:
            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

        cross_attn_processor_context = SwapCrossAttnContext(
            modified_text_embeddings=context.arguments.edited_conditioning,
            index_map=context.cross_attention_index_map,
            mask=context.cross_attention_mask,
            cross_attention_types_to_do=[],
        )

        added_cond_kwargs = None
        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
            }

        # no cross attention for unconditioning (negative prompt)
        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.unconditioned_embeddings.embeds,
            {"swap_cross_attn_context": cross_attn_processor_context},
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )

        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                "time_ids": conditioning_data.text_embeddings.add_time_ids,
            }

        # do requested cross attention types for conditioning (positive prompt)
        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.text_embeddings.embeds,
            {"swap_cross_attn_context": cross_attn_processor_context},
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x

    def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
        # To scale how much effect conditioning has, calculate the change it makes and then scale that.
        scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
        combined_next_x = unconditioned_next_x + scaled_delta
        return combined_next_x
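
    # This is standard classifier-free guidance: combined = uncond + scale * (cond - uncond).
    # For example, per element, with scale 7.5, uncond 0.2 and cond 0.3:
    #   0.2 + 7.5 * (0.3 - 0.2) = 0.95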

    def apply_symmetry(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through: float,
    ) -> torch.Tensor:
        # Reset our last percent through if this is our first step.
        if percent_through == 0.0:
            self.last_percent_through = 0.0

        if postprocessing_settings is None:
            return latents

        # Check for out of bounds
        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
        if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
            h_symmetry_time_pct = None

        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
        if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
            v_symmetry_time_pct = None

        dev = latents.device.type
        latents = latents.to(device="cpu")  # do the flips on CPU; `.to()` returns a new tensor, so assign it

        if (
            h_symmetry_time_pct is not None
            and self.last_percent_through < h_symmetry_time_pct
            and percent_through >= h_symmetry_time_pct
        ):
            # Horizontal symmetry occurs on the 3rd dimension of the latent
            width = latents.shape[3]
            x_flipped = torch.flip(latents, dims=[3])
            latents = torch.cat(
                [
                    latents[:, :, :, 0 : int(width / 2)],
                    x_flipped[:, :, :, int(width / 2) : int(width)],
                ],
                dim=3,
            )

        if (
            v_symmetry_time_pct is not None
            and self.last_percent_through < v_symmetry_time_pct
            and percent_through >= v_symmetry_time_pct
        ):
            # Vertical symmetry occurs on the 2nd dimension of the latent
            height = latents.shape[2]
            y_flipped = torch.flip(latents, dims=[2])
            latents = torch.cat(
                [
                    latents[:, :, 0 : int(height / 2)],
                    y_flipped[:, :, int(height / 2) : int(height)],
                ],
                dim=2,
            )

        self.last_percent_through = percent_through
        return latents.to(device=dev)
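
    # Sketch of the horizontal mirror performed above (sizes are illustrative assumptions):
    #
    #   latents = torch.randn(1, 4, 64, 64)
    #   flipped = torch.flip(latents, dims=[3])
    #   mirrored = torch.cat([latents[:, :, :, :32], flipped[:, :, :, 32:]], dim=3)
    #   # keeps the left half and replaces the right half with the left half's mirror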

    # todo: make this work
    @classmethod
    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)  # aka sigmas

        uncond_latents = None
        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]

        # below is fugly omg
        conditionings = [uc] + [c for c, weight in weighted_cond_list]
        weights = [1] + [weight for c, weight in weighted_cond_list]
        chunk_count = math.ceil(len(conditionings) / 2)
        deltas = None
        for chunk_index in range(chunk_count):
            offset = chunk_index * 2
            chunk_size = min(2, len(conditionings) - offset)

            if chunk_size == 1:
                c_in = conditionings[offset]
                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                latents_b = None
            else:
                c_in = torch.cat(conditionings[offset : offset + 2])
                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)

            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
            if chunk_index == 0:
                uncond_latents = latents_a
                deltas = latents_b - uncond_latents
            else:
                deltas = torch.cat((deltas, latents_a - uncond_latents))
                if latents_b is not None:
                    deltas = torch.cat((deltas, latents_b - uncond_latents))

        # merge the weighted deltas together into a single merged delta
        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
        normalize = False
        if normalize:
            per_delta_weights /= torch.sum(per_delta_weights)
        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
        return uncond_latents + deltas_merged * global_guidance_scale