Merge branch 'tildebyte-feat-samplers-add-remaining-k' into main

This adds the remaining k_* samplers to the dream.py script.
This commit is contained in:
Lincoln Stein 2022-08-24 11:19:45 -04:00
commit 7f4a5e946d
3 changed files with 21 additions and 10 deletions

View File

@ -67,7 +67,7 @@ class KSampler(object):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
None)
def gather(samples_ddim):

View File

@ -18,7 +18,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler_name= ['ddim','plms','klms'] // klms
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
@ -448,19 +448,29 @@ The vast majority of these arguments default to reasonable values.
except AttributeError:
raise SystemExit
msg = f'setting sampler to {self.sampler_name}'
if self.sampler_name=='plms':
print("setting sampler to plms")
self.sampler = PLMSSampler(self.model)
elif self.sampler_name == 'ddim':
print("setting sampler to ddim")
self.sampler = DDIMSampler(self.model)
elif self.sampler_name == 'klms':
print("setting sampler to klms")
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model,'dpm_2_ancestral')
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model,'dpm_2')
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model,'euler_ancestral')
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model,'euler')
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model,'heun')
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model,'lms')
else:
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model)
print(msg)
return self.model
def _load_model_from_config(self, config, ckpt):

View File

@ -226,6 +226,7 @@ def _reconstruct_switches(t2i,opt):
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
@ -266,9 +267,9 @@ def create_argv_parser():
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['plms','ddim', 'klms'],
default='klms',
help="which sampler to use (klms) - can only be set on command line")
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
default='k_lms',
help="which sampler to use (k_lms) - can only be set on command line")
parser.add_argument('--outdir',
'-o',
type=str,