mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui
synced 2025-01-08 12:07:30 +08:00
add new sampler DDIM CFG++
This commit is contained in:
parent
a30b19dd55
commit
663a4d80df
@ -58,6 +58,8 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
self.model_wrap = None
|
self.model_wrap = None
|
||||||
self.p = None
|
self.p = None
|
||||||
|
|
||||||
|
self.last_noise_uncond = None
|
||||||
|
|
||||||
# NOTE: masking before denoising can cause the original latents to be oversmoothed
|
# NOTE: masking before denoising can cause the original latents to be oversmoothed
|
||||||
# as the original latents do not have noise
|
# as the original latents do not have noise
|
||||||
self.mask_before_denoising = False
|
self.mask_before_denoising = False
|
||||||
@ -160,6 +162,8 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
# so is_edit_model is set to False to support AND composition.
|
# so is_edit_model is set to False to support AND composition.
|
||||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||||
|
|
||||||
|
is_cfg_pp = 'CFG++' in self.sampler.config.name
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
@ -273,10 +277,16 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||||
cfg_denoised_callback(denoised_params)
|
cfg_denoised_callback(denoised_params)
|
||||||
|
|
||||||
|
if is_cfg_pp:
|
||||||
|
self.last_noise_uncond = x_out[-uncond.shape[0]:]
|
||||||
|
self.last_noise_uncond = torch.clone(self.last_noise_uncond)
|
||||||
|
|
||||||
if is_edit_model:
|
if is_edit_model:
|
||||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||||
elif skip_uncond:
|
elif skip_uncond:
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||||
|
elif is_cfg_pp:
|
||||||
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale/12.5) # CFG++ scale of (0, 1) maps to (1.0, 12.5)
|
||||||
else:
|
else:
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import modules.shared as shared
|
|||||||
|
|
||||||
samplers_timesteps = [
|
samplers_timesteps = [
|
||||||
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
|
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
|
||||||
|
('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),
|
||||||
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
|
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
|
||||||
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
|
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
|
||||||
]
|
]
|
||||||
|
@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
    """Run DDIM sampling with the CFG++ update rule.

    CFG++ (Manifold-constrained Classifier Free Guidance For Diffusion Models, 2024)
    differs from plain DDIM in the "direction to x_t" term: it is built from the
    unconditional noise prediction that the denoiser wrapper leaves in
    ``model.last_noise_uncond``, rather than from the guided prediction returned
    by the model call. The guided prediction is still used to estimate x0.

    The CFG scale is divided by 12.5 (by the denoiser wrapper, not here) to map
    CFG from [0.0, 12.5] to [0, 1.0].

    Args:
        model: callable denoiser; must expose ``inner_model.inner_model.alphas_cumprod``
            and set ``last_noise_uncond`` on every call.
        x: starting latent. NOTE(review): the per-step scalars are broadcast with a
            (batch, 1, 1, 1) tensor, so a 4-D latent is assumed - confirm with callers.
        timesteps: tensor of timestep indices; iterated from the last entry down
            to index 1.
        extra_args: optional keyword arguments forwarded to ``model``.
        callback: optional callable invoked once per step with a dict holding
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: forwarded to ``tqdm.trange`` to suppress the progress bar.
        eta: multiplier on the per-step stochastic noise level (0.0 disables it).

    Returns:
        The latent after the final sampling step.
    """
    cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas = cumprod[timesteps]
    # Shift the timestep list right by one (zero-padded at the front) to look up
    # the alpha of the step each iteration moves to.
    shifted_timesteps = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
    alphas_prev = cumprod[shifted_timesteps].to(float64(x))
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

    alphas_on_cpu = alphas.cpu()
    alphas_prev_np = alphas_prev.cpu().numpy()
    sigmas = eta * np.sqrt((1 - alphas_prev_np) / (1 - alphas_on_cpu) * (1 - alphas_on_cpu / alphas_prev_np))

    if extra_args is None:
        extra_args = {}

    batch = x.shape[0]
    s_in = x.new_ones((batch))
    s_x = x.new_ones((batch, 1, 1, 1))
    last_index = len(timesteps) - 1

    for i in tqdm.trange(last_index, disable=disable):
        index = last_index - i

        # The call returns the guided noise estimate and, as a side effect,
        # refreshes model.last_noise_uncond for this step.
        e_t = model(x, timesteps[index].item() * s_in, **extra_args)
        uncond_noise = model.last_noise_uncond

        a_t = alphas[index].item() * s_x
        a_prev = alphas_prev[index].item() * s_x
        sigma_t = sigmas[index].item() * s_x
        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x

        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        # CFG++: the direction term uses the unconditional prediction.
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * uncond_noise
        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
        x = a_prev.sqrt() * pred_x0 + dir_xt + noise

        if callback is None:
            continue
        callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})

    return x
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
|
Loading…
Reference in New Issue
Block a user