mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 04:18:46 +08:00
first attempt to fold k_lms changes proposed by hwharrison and bmaltais
This commit is contained in:
parent
d340afc9e5
commit
bb91ca0462
64
ldm/models/diffusion/ksampler.py
Normal file
64
ldm/models/diffusion/ksampler.py
Normal file
@ -0,0 +1,64 @@
|
||||
'''wrapper around part of Karen Crownson's k-duffsion library, making it call compatible with other Samplers'''
|
||||
import k_diffusion as K
|
||||
import torch.nn as nn
|
||||
|
||||
class CFGDenoiser(nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
class KSampler(object):
|
||||
def __init__(self,model,schedule="lms", **kwargs):
|
||||
super().__init__()
|
||||
self.model = K.external.CompVisDenoiser(model)
|
||||
self.accelerator = accelerate.Accelerator()
|
||||
self.device = accelerator.device
|
||||
self.schedule = schedule
|
||||
|
||||
# most of these arguments are ignored and are only present for compatibility with
|
||||
# other samples
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
|
||||
sigmas = self.model.get_sigmas(S)
|
||||
if x_T:
|
||||
x = x_T
|
||||
else:
|
||||
x = torch.randn([batch_size, *shape], device=device) * sigmas[0] # for GPU draw
|
||||
model_wrap_cfg = CFGDenoiser(self.model)
|
||||
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
|
||||
return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process),
|
||||
None)
|
||||
|
||||
def gather(samples_ddim):
|
||||
return self.accelerator.gather(samples_ddim)
|
@ -11,7 +11,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
|
||||
batch_size = <integer> // how many images to generate per sampling (1)
|
||||
steps = <integer> // 50
|
||||
seed = <integer> // current system time
|
||||
sampler = ['ddim','plms'] // ddim
|
||||
sampler = ['ddim','plms','klms'] // klms
|
||||
grid = <boolean> // false
|
||||
width = <integer> // image width, multiple of 64 (512)
|
||||
height = <integer> // image height, multiple of 64 (512)
|
||||
@ -62,8 +62,9 @@ import time
|
||||
import math
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
|
||||
class T2I:
|
||||
"""T2I class
|
||||
@ -101,7 +102,7 @@ class T2I:
|
||||
cfg_scale=7.5,
|
||||
weights="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
|
||||
sampler="plms",
|
||||
sampler="klms",
|
||||
latent_channels=4,
|
||||
downsampling_factor=8,
|
||||
ddim_eta=0.0, # deterministic
|
||||
@ -387,6 +388,9 @@ class T2I:
|
||||
elif self.sampler_name == 'ddim':
|
||||
print("setting sampler to ddim")
|
||||
self.sampler = DDIMSampler(self.model)
|
||||
elif self.sampler_name == 'klms':
|
||||
print("setting sampler to klms")
|
||||
self.sampler = KSampler(self.model,'lms')
|
||||
else:
|
||||
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
|
||||
self.sampler = PLMSSampler(self.model)
|
||||
|
Loading…
Reference in New Issue
Block a user