mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 04:18:46 +08:00
Merge branch 'tildebyte-feat-samplers-add-remaining-k' into main
This adds the remaining k_* samplers to the dream.py script.
This commit is contained in:
commit
7f4a5e946d
@ -67,7 +67,7 @@ class KSampler(object):
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
|
||||
model_wrap_cfg = CFGDenoiser(self.model)
|
||||
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
|
||||
return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
|
||||
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
|
||||
None)
|
||||
|
||||
def gather(samples_ddim):
|
||||
|
@ -18,7 +18,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
|
||||
batch_size = <integer> // how many images to generate per sampling (1)
|
||||
steps = <integer> // 50
|
||||
seed = <integer> // current system time
|
||||
sampler_name= ['ddim','plms','klms'] // klms
|
||||
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
grid = <boolean> // false
|
||||
width = <integer> // image width, multiple of 64 (512)
|
||||
height = <integer> // image height, multiple of 64 (512)
|
||||
@ -448,19 +448,29 @@ The vast majority of these arguments default to reasonable values.
|
||||
except AttributeError:
|
||||
raise SystemExit
|
||||
|
||||
msg = f'setting sampler to {self.sampler_name}'
|
||||
if self.sampler_name=='plms':
|
||||
print("setting sampler to plms")
|
||||
self.sampler = PLMSSampler(self.model)
|
||||
elif self.sampler_name == 'ddim':
|
||||
print("setting sampler to ddim")
|
||||
self.sampler = DDIMSampler(self.model)
|
||||
elif self.sampler_name == 'klms':
|
||||
print("setting sampler to klms")
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(self.model,'dpm_2_ancestral')
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
self.sampler = KSampler(self.model,'dpm_2')
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = KSampler(self.model,'euler_ancestral')
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = KSampler(self.model,'euler')
|
||||
elif self.sampler_name == 'k_heun':
|
||||
self.sampler = KSampler(self.model,'heun')
|
||||
elif self.sampler_name == 'k_lms':
|
||||
self.sampler = KSampler(self.model,'lms')
|
||||
else:
|
||||
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
|
||||
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
|
||||
self.sampler = PLMSSampler(self.model)
|
||||
|
||||
print(msg)
|
||||
|
||||
return self.model
|
||||
|
||||
def _load_model_from_config(self, config, ckpt):
|
||||
|
@ -226,6 +226,7 @@ def _reconstruct_switches(t2i,opt):
|
||||
switches.append(f'-W{opt.width or t2i.width}')
|
||||
switches.append(f'-H{opt.height or t2i.height}')
|
||||
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
||||
switches.append(f'-m{t2i.sampler_name}')
|
||||
if opt.init_img:
|
||||
switches.append(f'-I{opt.init_img}')
|
||||
if opt.strength and opt.init_img is not None:
|
||||
@ -266,9 +267,9 @@ def create_argv_parser():
|
||||
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
|
||||
parser.add_argument('--sampler','-m',
|
||||
dest="sampler_name",
|
||||
choices=['plms','ddim', 'klms'],
|
||||
default='klms',
|
||||
help="which sampler to use (klms) - can only be set on command line")
|
||||
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
|
||||
default='k_lms',
|
||||
help="which sampler to use (k_lms) - can only be set on command line")
|
||||
parser.add_argument('--outdir',
|
||||
'-o',
|
||||
type=str,
|
||||
|
Loading…
Reference in New Issue
Block a user