mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui
synced 2025-01-09 05:37:34 +08:00
Merge branch 'dev' into refiner
This commit is contained in:
commit
54c3e5c913
@ -1,5 +1,7 @@
|
||||
import math
|
||||
|
||||
import gradio as gr
|
||||
from modules import scripts, shared, ui_components, ui_settings
|
||||
from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
|
||||
from modules.ui_components import FormColumn
|
||||
|
||||
|
||||
@ -19,18 +21,37 @@ class ExtraOptionsSection(scripts.Script):
|
||||
def ui(self, is_img2img):
|
||||
self.comps = []
|
||||
self.setting_names = []
|
||||
self.infotext_fields = []
|
||||
|
||||
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
||||
|
||||
with gr.Blocks() as interface:
|
||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
|
||||
for setting_name in shared.opts.extra_options:
|
||||
with FormColumn():
|
||||
comp = ui_settings.create_setting_component(setting_name)
|
||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group():
|
||||
|
||||
self.comps.append(comp)
|
||||
self.setting_names.append(setting_name)
|
||||
row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols)
|
||||
|
||||
for row in range(row_count):
|
||||
with gr.Row():
|
||||
for col in range(shared.opts.extra_options_cols):
|
||||
index = row * shared.opts.extra_options_cols + col
|
||||
if index >= len(shared.opts.extra_options):
|
||||
break
|
||||
|
||||
setting_name = shared.opts.extra_options[index]
|
||||
|
||||
with FormColumn():
|
||||
comp = ui_settings.create_setting_component(setting_name)
|
||||
|
||||
self.comps.append(comp)
|
||||
self.setting_names.append(setting_name)
|
||||
|
||||
setting_infotext_name = mapping.get(setting_name)
|
||||
if setting_infotext_name is not None:
|
||||
self.infotext_fields.append((comp, setting_infotext_name))
|
||||
|
||||
def get_settings_values():
|
||||
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
||||
res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
||||
return res[0] if len(res) == 1 else res
|
||||
|
||||
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
|
||||
|
||||
@ -44,5 +65,8 @@ class ExtraOptionsSection(scripts.Script):
|
||||
|
||||
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
||||
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
|
||||
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion").needs_restart()
|
||||
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
|
||||
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
|
||||
}))
|
||||
|
||||
|
||||
|
@ -136,6 +136,11 @@ function setupImageForLightbox(e) {
|
||||
var event = isFirefox ? 'mousedown' : 'click';
|
||||
|
||||
e.addEventListener(event, function(evt) {
|
||||
if (evt.button == 1) {
|
||||
open(evt.target.src);
|
||||
evt.preventDefault();
|
||||
return;
|
||||
}
|
||||
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
||||
|
||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
||||
|
@ -416,10 +416,15 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
return res
|
||||
|
||||
if override_settings_component is not None:
|
||||
already_handled_fields = {key: 1 for _, key in paste_fields}
|
||||
|
||||
def paste_settings(params):
|
||||
vals = {}
|
||||
|
||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
||||
if param_name in already_handled_fields:
|
||||
continue
|
||||
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
@ -6,7 +6,7 @@ import numpy as np
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||
import gradio as gr
|
||||
|
||||
from modules import sd_samplers, images as imgutil
|
||||
from modules import images as imgutil
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
process_images(p)
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
@ -172,7 +172,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
seed_resize_from_h=seed_resize_from_h,
|
||||
seed_resize_from_w=seed_resize_from_w,
|
||||
seed_enable_extras=seed_enable_extras,
|
||||
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
|
||||
sampler_name=sampler_name,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=steps,
|
||||
|
@ -139,6 +139,27 @@ def check_run_python(code: str) -> bool:
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def git_fix_workspace(dir, name):
|
||||
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
|
||||
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
|
||||
return
|
||||
|
||||
|
||||
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
|
||||
try:
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
if not autofix:
|
||||
return None
|
||||
|
||||
print(f"{errdesc}, attempting autofix...")
|
||||
git_fix_workspace(dir, name)
|
||||
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
|
||||
|
||||
def git_clone(url, dir, name, commithash=None):
|
||||
# TODO clone into temporary dir and move if successful
|
||||
|
||||
@ -146,12 +167,14 @@ def git_clone(url, dir, name, commithash=None):
|
||||
if commithash is None:
|
||||
return
|
||||
|
||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||
if current_hash == commithash:
|
||||
return
|
||||
|
||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||
run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
|
||||
run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||
|
||||
return
|
||||
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
|
@ -1119,9 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
|
||||
|
||||
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
|
||||
img2img_sampler_name = 'DDIM'
|
||||
|
||||
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
||||
|
||||
if self.latent_scale_mode is not None:
|
||||
|
@ -5,7 +5,7 @@ from types import MethodType
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
@ -34,8 +34,6 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
ldm.util.print = shared.ldm_print
|
||||
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||
|
||||
sd_hijack_inpainting.do_inpainting_hijack()
|
||||
|
||||
optimizers = []
|
||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
|
||||
|
@ -1,95 +0,0 @@
|
||||
import torch
|
||||
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
from ldm.models.diffusion.ddim import noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = {}
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
@ -372,7 +372,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
|
||||
sd_vae.load_vae(model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
@ -715,6 +715,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
model_data.set_sd_model(sd_model)
|
||||
sd_unet.apply_unet()
|
||||
|
||||
return sd_model
|
||||
|
||||
|
@ -1,17 +1,18 @@
|
||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||
|
||||
all_samplers = [
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
*sd_samplers_compvis.samplers_data_compvis,
|
||||
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
samplers = []
|
||||
samplers_for_img2img = []
|
||||
samplers_map = {}
|
||||
samplers_hidden = {}
|
||||
|
||||
|
||||
def find_sampler_config(name):
|
||||
@ -38,13 +39,11 @@ def create_sampler(name, model):
|
||||
|
||||
|
||||
def set_samplers():
|
||||
global samplers, samplers_for_img2img
|
||||
global samplers, samplers_for_img2img, samplers_hidden
|
||||
|
||||
hidden = set(shared.opts.hide_samplers)
|
||||
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
|
||||
|
||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||
samplers_hidden = set(shared.opts.hide_samplers)
|
||||
samplers = all_samplers
|
||||
samplers_for_img2img = all_samplers
|
||||
|
||||
samplers_map.clear()
|
||||
for sampler in all_samplers:
|
||||
@ -53,4 +52,8 @@ def set_samplers():
|
||||
samplers_map[alias.lower()] = sampler.name
|
||||
|
||||
|
||||
def visible_sampler_names():
|
||||
return [x.name for x in samplers if x.name not in samplers_hidden]
|
||||
|
||||
|
||||
set_samplers()
|
||||
|
203
modules/sd_samplers_cfg_denoiser.py
Normal file
203
modules/sd_samplers_cfg_denoiser.py
Normal file
@ -0,0 +1,203 @@
|
||||
import torch
|
||||
from modules import prompt_parser, devices, sd_samplers_common
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||
|
||||
|
||||
def catenate_conds(conds):
|
||||
if not isinstance(conds[0], dict):
|
||||
return torch.cat(conds)
|
||||
|
||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
|
||||
if not isinstance(cond, dict):
|
||||
return cond[a:b]
|
||||
|
||||
return {key: vec[a:b] for key, vec in cond.items()}
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
|
||||
if not isinstance(tensor, dict):
|
||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||
|
||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||
return tensor
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model, sampler):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.sampler = sampler
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
return x_out
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
if self.mask is not None:
|
||||
x = self.init_latent * self.mask + self.nmask * x
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
if isinstance(uncond, dict):
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||
else:
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
tensor = denoiser_params.text_cond
|
||||
uncond = denoiser_params.text_uncond
|
||||
skip_uncond = False
|
||||
|
||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||
skip_uncond = True
|
||||
x_in = x_in[:-batch_size]
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
|
||||
self.padded_cond_uncond = False
|
||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = catenate_conds([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = subscript_cond(tensor, a, b)
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
preview = self.sampler.last_latent
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
|
||||
else:
|
||||
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
sd_samplers_common.store_latent(preview)
|
||||
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||
denoised = after_cfg_callback_params.x
|
||||
|
||||
self.step += 1
|
||||
return denoised
|
||||
|
@ -1,9 +1,11 @@
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
from collections import namedtuple, deque
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
import k_diffusion.sampling
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
@ -155,3 +157,137 @@ def apply_refiner(sampler):
|
||||
return True
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(self, funcname):
|
||||
self.funcname = funcname
|
||||
self.func = funcname
|
||||
self.extra_params = []
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
self.s_churn = 0.0
|
||||
self.s_tmin = 0.0
|
||||
self.s_tmax = float('inf')
|
||||
self.s_noise = 1.0
|
||||
|
||||
self.eta_option_field = 'eta_ancestral'
|
||||
self.eta_infotext_field = 'Eta'
|
||||
|
||||
self.conditioning_key = shared.sd_model.model.conditioning_key
|
||||
|
||||
self.model_wrap = None
|
||||
self.model_wrap_cfg = None
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except RecursionError:
|
||||
print(
|
||||
'Encountered RecursionError during sampling, returning last latent. '
|
||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||
'You should try to use a smaller rho value instead.'
|
||||
)
|
||||
return self.last_latent
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p) -> dict:
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != 1.0:
|
||||
p.extra_generation_params[self.eta_infotext_field] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
if len(self.extra_params) > 0:
|
||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
|
||||
|
||||
|
@ -1,232 +0,0 @@
|
||||
import math
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules.shared import state
|
||||
from modules import sd_samplers_common, prompt_parser, shared
|
||||
import modules.models.diffusion.uni_pc
|
||||
|
||||
|
||||
samplers_data_compvis = [
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
|
||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
|
||||
]
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.p = None
|
||||
self.sampler = constructor(shared.sd_model)
|
||||
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
||||
self.orig_p_sample_ddim = None
|
||||
if self.is_plms:
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms
|
||||
elif self.is_ddim:
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.steps = None
|
||||
self.step = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
self.steps = steps
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
||||
|
||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||
|
||||
return res
|
||||
|
||||
def update_inner_model(self):
|
||||
self.sampler.model = shared.sd_model
|
||||
|
||||
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
uc_image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
image_conditioning = cond["c_adm"]
|
||||
uc_image_conditioning = unconditional_conditioning["c_adm"]
|
||||
else:
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||
cond = tensor
|
||||
|
||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
||||
# not 100% correct but should work well enough
|
||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
||||
last_vector = unconditional_conditioning[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
||||
|
||||
if self.mask is not None:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x = img_orig * self.mask + self.nmask * x
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
|
||||
else:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
return x, ts, cond, unconditional_conditioning
|
||||
|
||||
def update_step(self, last_latent):
|
||||
if self.mask is not None:
|
||||
self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
|
||||
else:
|
||||
self.last_latent = last_latent
|
||||
|
||||
sd_samplers_common.store_latent(self.last_latent)
|
||||
|
||||
self.step += 1
|
||||
state.sampling_step = self.step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def after_sample(self, x, ts, cond, uncond, res):
|
||||
if not self.is_unipc:
|
||||
self.update_step(res[1])
|
||||
|
||||
return x, ts, cond, uncond, res
|
||||
|
||||
def unipc_after_update(self, x, model_x):
|
||||
self.update_step(x)
|
||||
|
||||
def initialize(self, p):
|
||||
self.p = p
|
||||
|
||||
if self.is_ddim:
|
||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||
else:
|
||||
self.eta = 0.0
|
||||
|
||||
if self.eta != 0.0:
|
||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
||||
|
||||
if self.is_unipc:
|
||||
keys = [
|
||||
('UniPC variant', 'uni_pc_variant'),
|
||||
('UniPC skip type', 'uni_pc_skip_type'),
|
||||
('UniPC order', 'uni_pc_order'),
|
||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||
]
|
||||
|
||||
for name, key in keys:
|
||||
v = getattr(shared.opts, key)
|
||||
if v != shared.opts.get_default(key):
|
||||
p.extra_generation_params[name] = v
|
||||
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||
if self.is_unipc:
|
||||
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
|
||||
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
|
||||
if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
|
||||
num_steps = shared.opts.uni_pc_order
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == math.floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
|
||||
else:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
self.initialize(p)
|
||||
|
||||
self.init_latent = None
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
|
||||
else:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
@ -1,17 +1,16 @@
|
||||
from collections import deque
|
||||
import torch
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
|
||||
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules.shared import opts, state
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
@ -28,10 +27,6 @@ samplers_k_diffusion = [
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
@ -57,342 +52,17 @@ k_diffusion_scheduler = {
|
||||
}
|
||||
|
||||
|
||||
def catenate_conds(conds):
|
||||
if not isinstance(conds[0], dict):
|
||||
return torch.cat(conds)
|
||||
|
||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
|
||||
if not isinstance(cond, dict):
|
||||
return cond[a:b]
|
||||
|
||||
return {key: vec[a:b] for key, vec in cond.items()}
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
|
||||
if not isinstance(tensor, dict):
|
||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||
|
||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||
return tensor
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
super().__init__()
|
||||
self.sampler = sampler
|
||||
self.model_wrap = None
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.steps = None
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.p = None
|
||||
|
||||
@property
|
||||
def inner_model(self):
|
||||
if self.model_wrap is None:
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
|
||||
|
||||
return self.model_wrap
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def update_inner_model(self):
|
||||
self.model_wrap = None
|
||||
|
||||
c, uc = self.p.get_conds()
|
||||
self.sampler.sampler_extra_args['cond'] = c
|
||||
self.sampler.sampler_extra_args['uncond'] = uc
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
if sd_samplers_common.apply_refiner(self):
|
||||
cond = self.sampler.sampler_extra_args['cond']
|
||||
uncond = self.sampler.sampler_extra_args['uncond']
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
if isinstance(uncond, dict):
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||
else:
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
tensor = denoiser_params.text_cond
|
||||
uncond = denoiser_params.text_uncond
|
||||
skip_uncond = False
|
||||
|
||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||
skip_uncond = True
|
||||
x_in = x_in[:-batch_size]
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
|
||||
self.padded_cond_uncond = False
|
||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = catenate_conds([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = subscript_cond(tensor, a, b)
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||
denoised = after_cfg_callback_params.x
|
||||
|
||||
self.step += 1
|
||||
return denoised
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
def __init__(self, funcname, sd_model):
|
||||
self.p = None
|
||||
self.funcname = funcname
|
||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||
|
||||
super().__init__(funcname)
|
||||
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.sampler_extra_args = {}
|
||||
self.model_wrap_cfg = CFGDenoiser(self)
|
||||
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||
|
||||
# NOTE: These are also defined in the StableDiffusionProcessing class.
|
||||
# They should have been here to begin with but we're going to
|
||||
# leave that class __init__ signature alone.
|
||||
self.s_churn = 0.0
|
||||
self.s_tmin = 0.0
|
||||
self.s_tmax = float('inf')
|
||||
self.s_noise = 1.0
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
if opts.live_preview_content == "Combined":
|
||||
sd_samplers_common.store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
self.model_wrap_cfg.steps = steps
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except RecursionError:
|
||||
print(
|
||||
'Encountered RecursionError during sampling, returning last latent. '
|
||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||
'You should try to use a smaller rho value instead.'
|
||||
)
|
||||
return self.last_latent
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p: StableDiffusionProcessing):
|
||||
self.p = p
|
||||
self.model_wrap_cfg.p = p
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != 1.0:
|
||||
p.extra_generation_params["Eta"] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
if len(self.extra_params) > 0:
|
||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
||||
return extra_params_kwargs
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self)
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
@ -444,22 +114,12 @@ class KDiffusionSampler:
|
||||
|
||||
return sigmas
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
@ -508,12 +168,14 @@ class KDiffusionSampler:
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
|
||||
if 'sigma_min' in parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
|
||||
if 'sigmas' in parameters:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
if self.config.options.get('brownian_noise', False):
|
||||
@ -535,3 +197,4 @@ class KDiffusionSampler:
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
|
147
modules/sd_samplers_timesteps.py
Normal file
147
modules/sd_samplers_timesteps.py
Normal file
@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import inspect
|
||||
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
|
||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
|
||||
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
|
||||
samplers_timesteps = [
|
||||
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
|
||||
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
|
||||
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
|
||||
]
|
||||
|
||||
|
||||
samplers_data_timesteps = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_timesteps
|
||||
]
|
||||
|
||||
|
||||
class CompVisTimestepsDenoiser(torch.nn.Module):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.inner_model = model
|
||||
|
||||
def forward(self, input, timesteps, **kwargs):
|
||||
return self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||
|
||||
|
||||
class CompVisTimestepsVDenoiser(torch.nn.Module):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.inner_model = model
|
||||
|
||||
def predict_eps_from_z_and_v(self, x_t, t, v):
|
||||
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
|
||||
|
||||
def forward(self, input, timesteps, **kwargs):
|
||||
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||
e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
|
||||
return e_t
|
||||
|
||||
|
||||
class CFGDenoiserTimesteps(CFGDenoiser):
|
||||
|
||||
def __init__(self, model, sampler):
|
||||
super().__init__(model, sampler)
|
||||
|
||||
self.alphas = model.inner_model.alphas_cumprod
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
ts = int(sigma.item())
|
||||
|
||||
s_in = x_in.new_ones([x_in.shape[0]])
|
||||
a_t = self.alphas[ts].item() * s_in
|
||||
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||
|
||||
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
||||
|
||||
return pred_x0
|
||||
|
||||
|
||||
class CompVisSampler(sd_samplers_common.Sampler):
|
||||
def __init__(self, funcname, sd_model):
|
||||
super().__init__(funcname)
|
||||
|
||||
self.eta_option_field = 'eta_ddim'
|
||||
self.eta_infotext_field = 'Eta DDIM'
|
||||
|
||||
denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
|
||||
self.model_wrap = denoiser(sd_model)
|
||||
self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
|
||||
|
||||
def get_timesteps(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||
discard_next_to_last_sigma = True
|
||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
|
||||
|
||||
return timesteps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
timesteps = self.get_timesteps(p, steps)
|
||||
timesteps_sched = timesteps[:t_enc]
|
||||
|
||||
alphas_cumprod = shared.sd_model.alphas_cumprod
|
||||
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
|
||||
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
|
||||
|
||||
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'timesteps' in parameters:
|
||||
extra_params_kwargs['timesteps'] = timesteps_sched
|
||||
if 'is_img2img' in parameters:
|
||||
extra_params_kwargs['is_img2img'] = True
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps = steps or p.steps
|
||||
timesteps = self.get_timesteps(p, steps)
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'timesteps' in parameters:
|
||||
extra_params_kwargs['timesteps'] = timesteps
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
return samples
|
||||
|
135
modules/sd_samplers_timesteps_impl.py
Normal file
135
modules/sd_samplers_timesteps_impl.py
Normal file
@ -0,0 +1,135 @@
|
||||
import torch
|
||||
import tqdm
|
||||
import k_diffusion.sampling
|
||||
import numpy as np
|
||||
|
||||
from modules import shared
|
||||
from modules.models.diffusion.uni_pc import uni_pc
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
alphas = alphas_cumprod[timesteps]
|
||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
|
||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||
index = len(timesteps) - 1 - i
|
||||
|
||||
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
||||
|
||||
a_t = alphas[index].item() * s_in
|
||||
a_prev = alphas_prev[index].item() * s_in
|
||||
sigma_t = sigmas[index].item() * s_in
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
|
||||
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||
noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
|
||||
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
alphas = alphas_cumprod[timesteps]
|
||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
|
||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
old_eps = []
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = alphas[index].item() * s_in
|
||||
a_prev = alphas_prev[index].item() * s_in
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev).sqrt() * e_t
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
||||
return x_prev, pred_x0
|
||||
|
||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||
index = len(timesteps) - 1 - i
|
||||
ts = timesteps[index].item() * s_in
|
||||
t_next = timesteps[max(index - 1, 0)].item() * s_in
|
||||
|
||||
e_t = model(x, ts, **extra_args)
|
||||
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = model(x_prev, t_next, **extra_args)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
else:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
|
||||
x = x_prev
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class UniPCCFG(uni_pc.UniPC):
|
||||
def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
|
||||
super().__init__(None, *args, **kwargs)
|
||||
|
||||
def after_update(x, model_x):
|
||||
callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
|
||||
self.index += 1
|
||||
|
||||
self.cfg_model = cfg_model
|
||||
self.extra_args = extra_args
|
||||
self.callback = callback
|
||||
self.index = 0
|
||||
self.after_update = after_update
|
||||
|
||||
def get_model_input_time(self, t_continuous):
|
||||
return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
|
||||
|
||||
def model(self, x, t):
|
||||
t_input = self.get_model_input_time(t)
|
||||
|
||||
res = self.cfg_model(x, t_input, **self.extra_args)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
|
||||
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
|
||||
unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
|
||||
x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
||||
|
||||
return x
|
@ -1,5 +1,7 @@
|
||||
import os
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file):
|
||||
return None
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file):
|
||||
if shared.cmd_opts.vae_path is not None:
|
||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||
@dataclass
|
||||
class VaeResolution:
|
||||
vae: str = None
|
||||
source: str = None
|
||||
resolved: bool = True
|
||||
|
||||
def tuple(self):
|
||||
return self.vae, self.source
|
||||
|
||||
|
||||
def is_automatic():
|
||||
return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||
|
||||
|
||||
def resolve_vae_from_setting() -> VaeResolution:
|
||||
if shared.opts.sd_vae == "None":
|
||||
return VaeResolution()
|
||||
|
||||
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
||||
if vae_from_options is not None:
|
||||
return VaeResolution(vae_from_options, 'specified in settings')
|
||||
|
||||
if not is_automatic():
|
||||
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||
|
||||
return VaeResolution(resolved=False)
|
||||
|
||||
|
||||
def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
|
||||
metadata = extra_networks.get_user_metadata(checkpoint_file)
|
||||
vae_metadata = metadata.get("vae", None)
|
||||
if vae_metadata is not None and vae_metadata != "Automatic":
|
||||
if vae_metadata == "None":
|
||||
return None, None
|
||||
return VaeResolution()
|
||||
|
||||
vae_from_metadata = vae_dict.get(vae_metadata, None)
|
||||
if vae_from_metadata is not None:
|
||||
return vae_from_metadata, "from user metadata"
|
||||
return VaeResolution(vae_from_metadata, "from user metadata")
|
||||
|
||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||
return VaeResolution(resolved=False)
|
||||
|
||||
|
||||
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
|
||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
|
||||
return vae_near_checkpoint, 'found near the checkpoint'
|
||||
return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
|
||||
|
||||
if shared.opts.sd_vae == "None":
|
||||
return None, None
|
||||
return VaeResolution(resolved=False)
|
||||
|
||||
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
||||
if vae_from_options is not None:
|
||||
return vae_from_options, 'specified in settings'
|
||||
|
||||
if not is_automatic:
|
||||
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||
def resolve_vae(checkpoint_file) -> VaeResolution:
|
||||
if shared.cmd_opts.vae_path is not None:
|
||||
return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
|
||||
|
||||
return None, None
|
||||
if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
|
||||
return resolve_vae_from_setting()
|
||||
|
||||
res = resolve_vae_from_user_metadata(checkpoint_file)
|
||||
if res.resolved:
|
||||
return res
|
||||
|
||||
res = resolve_vae_near_checkpoint(checkpoint_file)
|
||||
if res.resolved:
|
||||
return res
|
||||
|
||||
res = resolve_vae_from_setting()
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def load_vae_dict(filename, map_location):
|
||||
@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
|
||||
if vae_file == unspecified:
|
||||
vae_file, vae_source = resolve_vae(checkpoint_file)
|
||||
vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
|
||||
else:
|
||||
vae_source = "from function argument"
|
||||
|
||||
|
@ -422,6 +422,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||
@ -481,7 +482,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
|
||||
"""),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
|
||||
@ -610,14 +611,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
||||
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}),
|
||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf"),
|
||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}).info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
|
||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}).info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}).info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
|
||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
|
||||
'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
|
||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||
'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
||||
@ -735,6 +736,10 @@ class Options:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
|
||||
# 1.6.0 VAE defaults
|
||||
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
|
||||
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
|
||||
|
||||
# 1.1.1 quicksettings list migration
|
||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||
|
@ -1,7 +1,7 @@
|
||||
from contextlib import closing
|
||||
|
||||
import modules.scripts
|
||||
from modules import sd_samplers, processing
|
||||
from modules import processing
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.shared as shared
|
||||
@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
@ -25,7 +25,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
seed_resize_from_h=seed_resize_from_h,
|
||||
seed_resize_from_w=seed_resize_from_w,
|
||||
seed_enable_extras=seed_enable_extras,
|
||||
sampler_name=sd_samplers.samplers[sampler_index].name,
|
||||
sampler_name=sampler_name,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=steps,
|
||||
@ -42,7 +42,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
hr_resize_x=hr_resize_x,
|
||||
hr_resize_y=hr_resize_y,
|
||||
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
||||
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
|
||||
hr_sampler_name=hr_sampler_name,
|
||||
hr_prompt=hr_prompt,
|
||||
hr_negative_prompt=hr_negative_prompt,
|
||||
override_settings=override_settings,
|
||||
|
@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import gradio_extensons # noqa: F401
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path
|
||||
from modules.ui_common import create_refresh_button
|
||||
@ -29,7 +29,6 @@ import modules.shared as shared
|
||||
import modules.images
|
||||
from modules import prompt_parser
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
|
||||
create_setting_component = ui_settings.create_setting_component
|
||||
@ -41,6 +40,9 @@ warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
|
||||
# Likewise, add explicit content-type header for certain missing image types
|
||||
mimetypes.add_type('image/webp', '.webp')
|
||||
|
||||
if not cmd_opts.share and not cmd_opts.listen:
|
||||
# fix gradio phoning home
|
||||
gradio.utils.version_check = lambda: None
|
||||
@ -357,14 +359,14 @@ def create_output_panel(tabname, outdir):
|
||||
def create_sampler_and_steps_selection(choices, tabname):
|
||||
if opts.samplers_in_dropdown:
|
||||
with FormRow(elem_id=f"sampler_selection_{tabname}"):
|
||||
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
|
||||
sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
||||
else:
|
||||
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
|
||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
||||
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
|
||||
sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
||||
|
||||
return steps, sampler_index
|
||||
return steps, sampler_name
|
||||
|
||||
|
||||
def ordered_ui_categories():
|
||||
@ -405,13 +407,13 @@ def create_ui():
|
||||
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
||||
extra_tabs.__enter__()
|
||||
|
||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
|
||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row(equal_height=False):
|
||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||
scripts.scripts_txt2img.prepare_ui()
|
||||
|
||||
for category in ordered_ui_categories():
|
||||
if category == "sampler":
|
||||
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
|
||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
||||
|
||||
elif category == "dimensions":
|
||||
with FormRow():
|
||||
@ -457,7 +459,7 @@ def create_ui():
|
||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||
|
||||
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
|
||||
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||
with gr.Column(scale=80):
|
||||
@ -517,7 +519,7 @@ def create_ui():
|
||||
toprow.negative_prompt,
|
||||
toprow.ui_styles.dropdown,
|
||||
steps,
|
||||
sampler_index,
|
||||
sampler_name,
|
||||
restore_faces,
|
||||
tiling,
|
||||
batch_count,
|
||||
@ -535,7 +537,7 @@ def create_ui():
|
||||
hr_resize_x,
|
||||
hr_resize_y,
|
||||
hr_checkpoint_name,
|
||||
hr_sampler_index,
|
||||
hr_sampler_name,
|
||||
hr_prompt,
|
||||
hr_negative_prompt,
|
||||
override_settings,
|
||||
@ -580,7 +582,7 @@ def create_ui():
|
||||
(toprow.prompt, "Prompt"),
|
||||
(toprow.negative_prompt, "Negative prompt"),
|
||||
(steps, "Steps"),
|
||||
(sampler_index, "Sampler"),
|
||||
(sampler_name, "Sampler"),
|
||||
(restore_faces, "Face restoration"),
|
||||
(cfg_scale, "CFG scale"),
|
||||
(seed, "Seed"),
|
||||
@ -602,7 +604,7 @@ def create_ui():
|
||||
(hr_resize_x, "Hires resize-1"),
|
||||
(hr_resize_y, "Hires resize-2"),
|
||||
(hr_checkpoint_name, "Hires checkpoint"),
|
||||
(hr_sampler_index, "Hires sampler"),
|
||||
(hr_sampler_name, "Hires sampler"),
|
||||
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
||||
(hr_prompt, "Hires prompt"),
|
||||
(hr_negative_prompt, "Hires negative prompt"),
|
||||
@ -618,7 +620,7 @@ def create_ui():
|
||||
toprow.prompt,
|
||||
toprow.negative_prompt,
|
||||
steps,
|
||||
sampler_index,
|
||||
sampler_name,
|
||||
cfg_scale,
|
||||
seed,
|
||||
width,
|
||||
@ -741,7 +743,7 @@ def create_ui():
|
||||
|
||||
for category in ordered_ui_categories():
|
||||
if category == "sampler":
|
||||
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
|
||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
||||
|
||||
elif category == "dimensions":
|
||||
with FormRow():
|
||||
@ -873,7 +875,7 @@ def create_ui():
|
||||
init_img_inpaint,
|
||||
init_mask_inpaint,
|
||||
steps,
|
||||
sampler_index,
|
||||
sampler_name,
|
||||
mask_blur,
|
||||
mask_alpha,
|
||||
inpainting_fill,
|
||||
@ -969,7 +971,7 @@ def create_ui():
|
||||
(toprow.prompt, "Prompt"),
|
||||
(toprow.negative_prompt, "Negative prompt"),
|
||||
(steps, "Steps"),
|
||||
(sampler_index, "Sampler"),
|
||||
(sampler_name, "Sampler"),
|
||||
(restore_faces, "Face restoration"),
|
||||
(cfg_scale, "CFG scale"),
|
||||
(image_cfg_scale, "Image CFG scale"),
|
||||
|
@ -1,6 +1,6 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import ui_extra_networks_user_metadata, sd_vae
|
||||
from modules import ui_extra_networks_user_metadata, sd_vae, shared
|
||||
from modules.ui_common import create_refresh_button
|
||||
|
||||
|
||||
@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
|
||||
|
||||
self.write_user_metadata(name, user_metadata)
|
||||
|
||||
def update_vae(self, name):
|
||||
if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
|
||||
sd_vae.reload_vae_weights()
|
||||
|
||||
def put_values_into_components(self, name):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
values = super().put_values_into_components(name)
|
||||
@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
|
||||
]
|
||||
|
||||
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
|
||||
self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
|
||||
|
||||
|
18
webui.py
18
webui.py
@ -211,7 +211,7 @@ def configure_sigint_handler():
|
||||
def configure_opts_onchange():
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
@ -341,6 +341,7 @@ def api_only():
|
||||
setup_middleware(app)
|
||||
api = create_api(app)
|
||||
|
||||
modules.script_callbacks.before_ui_callback()
|
||||
modules.script_callbacks.app_started_callback(None, app)
|
||||
|
||||
print(f"Startup time: {startup_timer.summary()}.")
|
||||
@ -371,6 +372,13 @@ def webui():
|
||||
|
||||
gradio_auth_creds = list(get_gradio_auth_creds()) or None
|
||||
|
||||
auto_launch_browser = False
|
||||
if os.getenv('SD_WEBUI_RESTARTING') != '1':
|
||||
if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
|
||||
auto_launch_browser = True
|
||||
elif shared.opts.auto_launch_browser == "Local":
|
||||
auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok])
|
||||
|
||||
app, local_url, share_url = shared.demo.launch(
|
||||
share=cmd_opts.share,
|
||||
server_name=server_name,
|
||||
@ -380,7 +388,7 @@ def webui():
|
||||
ssl_verify=cmd_opts.disable_tls_verify,
|
||||
debug=cmd_opts.gradio_debug,
|
||||
auth=gradio_auth_creds,
|
||||
inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1',
|
||||
inbrowser=auto_launch_browser,
|
||||
prevent_thread_lock=True,
|
||||
allowed_paths=cmd_opts.gradio_allowed_path,
|
||||
app_kwargs={
|
||||
@ -390,9 +398,6 @@ def webui():
|
||||
root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
|
||||
)
|
||||
|
||||
# after initial launch, disable --autolaunch for subsequent restarts
|
||||
cmd_opts.autolaunch = False
|
||||
|
||||
startup_timer.record("gradio launch")
|
||||
|
||||
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
||||
@ -437,6 +442,9 @@ def webui():
|
||||
shared.demo.close()
|
||||
break
|
||||
|
||||
# disable auto launch webui in browser for subsequent UI Reload
|
||||
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||
|
||||
print('Restarting UI...')
|
||||
shared.demo.close()
|
||||
time.sleep(0.5)
|
||||
|
Loading…
Reference in New Issue
Block a user