first draft at big refactoring; will be broken

Lincoln Stein 2022-08-24 17:52:34 -04:00
parent 7ea168227c
commit 9133087850
3 changed files with 238 additions and 323 deletions

ldm/dream_util.py (new file, 143 additions)

@@ -0,0 +1,143 @@
'''Utilities for dealing with PNG images and their path names'''
import os
import re
import atexit
from PIL import Image, PngImagePlugin

# ---------------readline utilities---------------------
try:
    import readline
    readline_available = True
except ImportError:
    readline_available = False

class Completer():
    def __init__(self, options):
        self.options = sorted(options)
        return

    def complete(self, text, state):
        buffer = readline.get_line_buffer()

        if text.startswith(('-I', '--init_img')):
            return self._path_completions(text, state, ('.png',))

        if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
            return self._path_completions(text, state, ())

        response = None
        if state == 0:
            # This is the first time for this text, so build a match list.
            if text:
                self.matches = [s
                                for s in self.options
                                if s and s.startswith(text)]
            else:
                self.matches = self.options[:]

        # Return the state'th item from the match list,
        # if we have that many.
        try:
            response = self.matches[state]
        except IndexError:
            response = None
        return response

    def _path_completions(self, text, state, extensions):
        # get the path so far
        if text.startswith('-I'):
            path = text.replace('-I', '', 1).lstrip()
        elif text.startswith('--init_img='):
            path = text.replace('--init_img=', '', 1).lstrip()
        else:
            path = text

        matches = list()

        path = os.path.expanduser(path)
        if len(path) == 0:
            matches.append(text + './')
        else:
            dir      = os.path.dirname(path)
            dir_list = os.listdir(dir)
            for n in dir_list:
                if n.startswith('.') and len(n) > 1:
                    continue
                full_path = os.path.join(dir, n)
                if full_path.startswith(path):
                    if os.path.isdir(full_path):
                        matches.append(os.path.join(os.path.dirname(text), n) + '/')
                    elif n.endswith(extensions):
                        matches.append(os.path.join(os.path.dirname(text), n))

        try:
            response = matches[state]
        except IndexError:
            response = None
        return response

if readline_available:
    readline.set_completer(Completer(['cd', 'pwd',
                                      '--steps', '-s', '--seed', '-S', '--iterations', '-n', '--batch_size', '-b',
                                      '--width', '-W', '--height', '-H', '--cfg_scale', '-C', '--grid', '-g',
                                      '--individual', '-i', '--init_img', '-I', '--strength', '-f', '-v', '--variants']).complete)
    readline.set_completer_delims(" ")
    readline.parse_and_bind('tab: complete')

    histfile = os.path.join(os.path.expanduser('~'), ".dream_history")
    try:
        readline.read_history_file(histfile)
        readline.set_history_length(1000)
    except FileNotFoundError:
        pass
    atexit.register(readline.write_history_file, histfile)

# -------------------image generation utils-----
class PngWriter:
    def __init__(self, opt):
        self.opt           = opt
        self.filepath      = None
        self.files_written = []

    def write_image(self, image, seed):
        # increment the filename in a sensible way before saving
        self.filepath = self.unique_filename(self.opt, seed, self.filepath)
        try:
            image.save(self.filepath)
        except IOError as e:
            print(e)
        self.files_written.append([self.filepath, seed])

    def unique_filename(self, opt, seed, previouspath):
        # filenames are numbered sequentially: 000001.<seed>.png, with a .NN series suffix for batches
        outdir = opt.outdir
        if previouspath is None:
            # sort reverse alphabetically until we find max+1
            dirlist = sorted(os.listdir(outdir), reverse=True)
            # find the first filename that matches our pattern or return 000000.0.png
            filename  = next((f for f in dirlist if re.match(r'^(\d+)\..*\.png', f)), '0000000.0.png')
            basecount = int(filename.split('.', 1)[0])
            basecount += 1
            if opt.batch_size > 1:
                filename = f'{basecount:06}.{seed}.01.png'
            else:
                filename = f'{basecount:06}.{seed}.png'
            return os.path.join(outdir, filename)
        else:
            basename = os.path.basename(previouspath)
            x = re.match(r'^(\d+)\..*\.png', basename)
            if not x:
                # previous name did not match the pattern; start a fresh scan of the directory
                return self.unique_filename(opt, seed, None)
            basecount = int(x.groups()[0])
            series = 0
            finished = False
            while not finished:
                series += 1
                filename = f'{basecount:06}.{seed}.png'
                if opt.batch_size > 1 or os.path.exists(os.path.join(outdir, filename)):
                    filename = f'{basecount:06}.{seed}.{series:02}.png'
                finished = not os.path.exists(os.path.join(outdir, filename))
            return os.path.join(outdir, filename)
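For orientation, a minimal sketch of how PngWriter appears intended to be used. The Namespace fields, the blank PIL image, and the directory name are illustrative assumptions for this sketch, not part of the commit (note that PngWriter itself does not create the output directory):

import os
import argparse
from PIL import Image
from ldm.dream_util import PngWriter

opt = argparse.Namespace(outdir='outputs/img-samples', batch_size=1)   # assumed fields
os.makedirs(opt.outdir, exist_ok=True)

writer = PngWriter(opt)
img = Image.new('RGB', (64, 64))      # stand-in for a generated sample
writer.write_image(img, seed=42)      # on an empty outdir, saves outputs/img-samples/000001.42.png
print(writer.files_written)           # [['outputs/img-samples/000001.42.png', 42]]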

Second changed file (the T2I class module)

@@ -23,7 +23,6 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
           width      = <integer>   // image width, multiple of 64 (512)
           height     = <integer>   // image height, multiple of 64 (512)
           cfg_scale  = <float>     // unconditional guidance scale (7.5)
-          fixed_code = <boolean>   // False
           )

 # do the slow model initialization
@@ -79,7 +78,6 @@ class T2I:
     """T2I class
     Attributes
     ----------
-    outdir
     model
     config
     iterations
@@ -87,12 +85,9 @@ class T2I:
     steps
     seed
     sampler_name
-    grid
-    individual
     width
     height
     cfg_scale
-    fixed_code
     latent_channels
     downsampling_factor
     precision
@@ -102,11 +97,8 @@ class T2I:
 The vast majority of these arguments default to reasonable values.
 """
     def __init__(self,
-                 outdir="outputs/txt2img-samples",
                  batch_size=1,
                  iterations = 1,
-                 width=512,
-                 height=512,
                  grid=False,
                  individual=None, # redundant
                  steps=50,
@@ -118,7 +110,6 @@ The vast majority of these arguments default to reasonable values.
                 latent_channels=4,
                 downsampling_factor=8,
                 ddim_eta=0.0,  # deterministic
-                fixed_code=False,
                 precision='autocast',
                 full_precision=False,
                 strength=0.75, # default in scripts/img2img.py
@@ -126,7 +117,6 @@ The vast majority of these arguments default to reasonable values.
                 latent_diffusion_weights=False,  # just to keep track of this parameter when regenerating prompt
                 device='cuda'
                 ):
-        self.outdir = outdir
         self.batch_size = batch_size
         self.iterations = iterations
         self.width = width
@@ -137,7 +127,6 @@ The vast majority of these arguments default to reasonable values.
         self.weights = weights
         self.config = config
         self.sampler_name = sampler_name
-        self.fixed_code = fixed_code
         self.latent_channels = latent_channels
         self.downsampling_factor = downsampling_factor
         self.ddim_eta = ddim_eta
@@ -154,16 +143,25 @@ The vast majority of these arguments default to reasonable values.
         else:
             self.seed = seed

-    @torch.no_grad()
-    def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
-                steps=None,seed=None,grid=None,individual=None,width=None,height=None,
-                cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,
-                skip_normalize=False,variants=None):    # note the "variants" option is an unused hack caused by how options are passed
-        """
-        Generate an image from the prompt, writing iteration images into the outdir
-        The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
-        """
-        outdir     = outdir     or self.outdir
+    def generate(self,
+                 # these are common
+                 prompt,
+                 batch_size=None,
+                 iterations=None,
+                 steps=None,
+                 seed=None,
+                 cfg_scale=None,
+                 ddim_eta=None,
+                 skip_normalize=False,
+                 image_callback=None,
+                 # these are specific to txt2img
+                 width=None,
+                 height=None,
+                 # these are specific to img2img
+                 init_img=None,
+                 strength=None,
+                 variants=None):
+        '''ldm.generate() is the common entry point for txt2img() and img2img()'''
         steps      = steps      or self.steps
         seed       = seed       or self.seed
         width      = width      or self.width
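Read together with dream_util.py above, the intended calling pattern for the new unified entry point looks roughly like the sketch below. This is illustrative only and rests on assumptions not confirmed by the diff (the ldm.simplet2i import path, the T2I constructor defaults described in the module docstring, and the Namespace fields); the commit message itself warns the draft is still broken.

import os
import argparse
from ldm.simplet2i import T2I          # assumed import path for the T2I class
from ldm.dream_util import PngWriter

opt = argparse.Namespace(outdir='outputs/img-samples', batch_size=1)   # illustrative fields
os.makedirs(opt.outdir, exist_ok=True)
writer = PngWriter(opt)
t2i    = T2I()                         # relies on the constructor defaults; model loads lazily

# txt2img path: no init_img supplied
results = t2i.generate(prompt='a sunlit forest', width=512, height=512,
                       image_callback=writer.write_image)

# img2img path: same entry point, routed through _img2img because init_img is set
results = t2i.generate(prompt='a sunlit forest, oil painting',
                       init_img='photo.png', strength=0.6,
                       image_callback=writer.write_image)

# each entry of results is [image, seed]; writer.files_written holds [filename, seed] pairs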
@@ -172,41 +170,57 @@ The vast majority of these arguments default to reasonable values.
         ddim_eta   = ddim_eta   or self.ddim_eta
         batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
-        strength   = strength   or self.strength    # not actually used here, but preserved for code refactoring
-        embedding_path = embedding_path or self.embedding_path
+        strength   = strength   or self.strength
         model = self.load_model()  # will instantiate the model or return it from cache
-        assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0"
         assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
+        assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'

-        # grid and individual are mutually exclusive, with individual taking priority.
-        # not necessary, but needed for compatability with dream bot
-        if (grid is None):
-            grid = self.grid
-        if individual:
-            grid = False
         data = [batch_size * [prompt]]
-        if grid:
-            callback = self.image2png
-        else:
-            callback = None
-        # make directories and establish names for the output files
-        os.makedirs(outdir, exist_ok=True)
-        start_code = None
-        if self.fixed_code:
-            start_code = torch.randn([batch_size,
-                                      self.latent_channels,
-                                      height // self.downsampling_factor,
-                                      width  // self.downsampling_factor],
-                                     device=self.device)
-        precision_scope = autocast if self.precision=="autocast" else nullcontext
+        scope = autocast if self.precision=="autocast" else nullcontext
+
+        tic = time.time()
+        if init_img:
+            assert os.path.exists(init_img),f'{init_img}: File not found'
+            results = self._img2img(prompt,
+                                    data=data,precision_scope=scope,
+                                    batch_size=batch_size,iterations=iterations,
+                                    steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                                    skip_normalize=skip_normalize,
+                                    init_img=init_img,strength=strength,variants=variants,
+                                    callback=image_callback)
+        else:
+            results = self._txt2img(prompt,
+                                    data=data,precision_scope=scope,
+                                    batch_size=batch_size,iterations=iterations,
+                                    steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                                    skip_normalize=skip_normalize,
+                                    width=width,height=height,
+                                    callback=image_callback)
+        toc = time.time()
+        print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
+        return results
+
+    @torch.no_grad()
+    def _txt2img(self,prompt,
+                 data,precision_scope,
+                 batch_size,iterations,
+                 steps,seed,cfg_scale,ddim_eta,
+                 skip_normalize,
+                 width,height,
+                 callback=callback):    # the callback is called each time a new Image is generated
+        """
+        Generate an image from the prompt, writing iteration images into the outdir
+        The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
+        """
         sampler = self.sampler
         images = list()
-        seeds  = list()
-        filename = None
         image_count = 0
-        tic = time.time()

         # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
         try:
@@ -239,38 +253,24 @@ The vast majority of these arguments default to reasonable values.
                         shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
                         samples_ddim, _ = sampler.sample(S=steps,
                                                          conditioning=c,
                                                          batch_size=batch_size,
                                                          shape=shape,
                                                          verbose=False,
                                                          unconditional_guidance_scale=cfg_scale,
                                                          unconditional_conditioning=uc,
-                                                         eta=ddim_eta,
-                                                         x_T=start_code)
+                                                         eta=ddim_eta)

                         x_samples_ddim = model.decode_first_stage(samples_ddim)
                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                        if not grid:
-                            for x_sample in x_samples_ddim:
-                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                                filename = self._unique_filename(outdir,previousname=filename,
-                                                                 seed=seed,isbatch=(batch_size>1))
-                                assert not os.path.exists(filename)
-                                Image.fromarray(x_sample.astype(np.uint8)).save(filename)
-                                images.append([filename,seed])
-                        else:
-                            all_samples.append(x_samples_ddim)
-                            seeds.append(seed)
-                        image_count += 1
+                        for x_sample in x_samples_ddim:
+                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                            image = Image.fromarray(x_sample.astype(np.uint8))
+                            images.append([image,seed])
+                            if callback is not None:
+                                callback(image,seed)
+
                         seed = self._new_seed()
-                if grid:
-                    images = self._make_grid(samples=all_samples,
-                                             seeds=seeds,
-                                             batch_size=batch_size,
-                                             iterations=iterations,
-                                             outdir=outdir)
         except KeyboardInterrupt:
             print('*interrupted*')
             print('Partial results will be returned; if --grid was requested, nothing will be returned.')
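The per-sample conversion above (decode, clamp to [0,1], rearrange channels, scale to 0-255) is the same in both the txt2img and img2img paths. A self-contained sketch of just that conversion, using a random tensor in place of a decoded latent, purely for illustration:

import torch
import numpy as np
from einops import rearrange
from PIL import Image

x_sample = torch.rand(3, 512, 512) * 2.0 - 1.0                    # stand-in for a decoded sample in [-1, 1]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)  # to [0, 1]
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image    = Image.fromarray(x_sample.astype(np.uint8))             # ready to hand to image_callback(image, seed)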
@@ -279,48 +279,20 @@ The vast majority of these arguments default to reasonable values.
         toc = time.time()
         print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
         return images

-    # There is lots of shared code between this and txt2img and should be refactored.
     @torch.no_grad()
-    def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
-                steps=None,seed=None,grid=None,individual=None,width=None,height=None,
-                cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,
-                skip_normalize=False,variants=None):    # note the "variants" option is an unused hack caused by how options are passed
+    def _img2img(self,prompt,
+                 data,precision_scope,
+                 batch_size,iterations,
+                 steps,seed,cfg_scale,ddim_eta,
+                 skip_normalize,
+                 init_img,strength,variants,
+                 callback):
         """
         Generate an image from the prompt and the initial image, writing iteration images into the outdir
-        The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
+        The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
         """
-        outdir     = outdir     or self.outdir
-        steps      = steps      or self.steps
-        seed       = seed       or self.seed
-        cfg_scale  = cfg_scale  or self.cfg_scale
-        ddim_eta   = ddim_eta   or self.ddim_eta
-        batch_size = batch_size or self.batch_size
-        iterations = iterations or self.iterations
-        strength   = strength   or self.strength
-        embedding_path = embedding_path or self.embedding_path
-
-        assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0"
-        assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
-
-        if init_img is None:
-            print("no init_img provided!")
-            return []
-
-        model = self.load_model()  # will instantiate the model or return it from cache
-        precision_scope = autocast if self.precision=="autocast" else nullcontext
-
-        # grid and individual are mutually exclusive, with individual taking priority.
-        # not necessary, but needed for compatability with dream bot
-        if (grid is None):
-            grid = self.grid
-        if individual:
-            grid = False
-
-        data = [batch_size * [prompt]]

         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name!='ddim':
@@ -329,33 +301,18 @@ The vast majority of these arguments default to reasonable values.
         else:
             sampler = self.sampler

-        # make directories and establish names for the output files
-        os.makedirs(outdir, exist_ok=True)
-
-        assert os.path.isfile(init_img)
         init_image = self._load_img(init_img).to(self.device)
         init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
         with precision_scope(self.device.type):
             init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

         sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)

-        try:
-            assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
-        except AssertionError:
-            print(f"strength must be between 0.0 and 1.0, but received value {strength}")
-            return []
-
         t_enc = int(strength * steps)
         print(f"target t_enc is {t_enc} steps")

         images = list()
-        seeds  = list()
-        filename = None
-        image_count = 0 # actual number of iterations performed
-        tic = time.time()

-        # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
         try:
             with precision_scope(self.device.type), model.ema_scope():
                 all_samples = list()
@@ -393,25 +350,13 @@ The vast majority of these arguments default to reasonable values.
                     x_samples = model.decode_first_stage(samples)
                     x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

-                    if not grid:
-                        for x_sample in x_samples:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            filename = self._unique_filename(outdir,previousname=filename,
-                                                             seed=seed,isbatch=(batch_size>1))
-                            assert not os.path.exists(filename)
-                            Image.fromarray(x_sample.astype(np.uint8)).save(filename)
-                            images.append([filename,seed])
-                    else:
-                        all_samples.append(x_samples)
-                        seeds.append(seed)
-                    image_count +=1
+                    for x_sample in x_samples:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        image = Image.fromarray(x_sample.astype(np.uint8))
+                        images.append([image,seed])
+                        if callback is not None:
+                            callback(image,seed)
                     seed = self._new_seed()
-                if grid:
-                    images = self._make_grid(samples=all_samples,
-                                             seeds=seeds,
-                                             batch_size=batch_size,
-                                             iterations=iterations,
-                                             outdir=outdir)

         except KeyboardInterrupt:
             print('*interrupted*')
@@ -419,26 +364,6 @@ The vast majority of these arguments default to reasonable values.
         except RuntimeError as e:
             print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
             traceback.print_exc()

-        toc = time.time()
-        print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
-        return images
-
-    def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
-        images = list()
-        n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations))
-        # save as grid
-        grid = torch.stack(samples, 0)
-        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-        grid = make_grid(grid, nrow=n_rows)
-
-        # to image
-        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-        filename = self._unique_filename(outdir,seed=seeds[0],grid_count=batch_size*iterations)
-        Image.fromarray(grid.astype(np.uint8)).save(filename)
-        for s in seeds:
-            images.append([filename,s])
         return images

     def _new_seed(self):
@@ -513,43 +438,6 @@ The vast majority of these arguments default to reasonable values.
         image = torch.from_numpy(image)
         return 2.*image - 1.

-    def _unique_filename(self,outdir,previousname=None,seed=0,isbatch=False,grid_count=None):
-        revision = 1
-
-        if previousname is None:
-            # sort reverse alphabetically until we find max+1
-            dirlist   = sorted(os.listdir(outdir),reverse=True)
-            # find the first filename that matches our pattern or return 000000.0.png
-            filename  = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
-            basecount = int(filename.split('.',1)[0])
-            basecount += 1
-            if grid_count is not None:
-                grid_label = f'grid#1-{grid_count}'
-                filename = f'{basecount:06}.{seed}.{grid_label}.png'
-            elif isbatch:
-                filename = f'{basecount:06}.{seed}.01.png'
-            else:
-                filename = f'{basecount:06}.{seed}.png'
-            return os.path.join(outdir,filename)
-        else:
-            previousname = os.path.basename(previousname)
-            x = re.match('^(\d+)\..*\.png',previousname)
-            if not x:
-                return self._unique_filename(outdir,previousname,seed)
-            basecount = int(x.groups()[0])
-            series = 0
-            finished = False
-            while not finished:
-                series += 1
-                filename = f'{basecount:06}.{seed}.png'
-                if isbatch or os.path.exists(os.path.join(outdir,filename)):
-                    filename = f'{basecount:06}.{seed}.{series:02}.png'
-                finished = not os.path.exists(os.path.join(outdir,filename))
-            return os.path.join(outdir,filename)
-
     def _split_weighted_subprompts(text):
         """
         grabs all text up to the first occurrence of ':'
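The removed _unique_filename logic survives, lightly adapted, as PngWriter.unique_filename in dream_util.py. For reference, a short illustration of the numbering scheme implied by the '^(\d+)\..*\.png' pattern; the example directory contents and seed are made up:

import re

existing  = ['000041.3141592653.png', '000042.2718281828.png']   # hypothetical outdir contents
latest    = sorted(existing, reverse=True)[0]
basecount = int(re.match(r'^(\d+)\..*\.png', latest).group(1)) + 1

seed = 1234567890
print(f'{basecount:06}.{seed}.png')      # next single image  -> 000043.1234567890.png
print(f'{basecount:06}.{seed}.01.png')   # first of a batch   -> 000043.1234567890.01.png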

Third changed file (the interactive dream prompt script)

@@ -8,13 +8,7 @@ import os
 import sys
 import copy
 from PIL import Image,PngImagePlugin
+from ldm.dream_util import Completer,PngWriter

-# readline unavailable on windows systems
-try:
-    import readline
-    readline_available = True
-except:
-    readline_available = False
 debugging = False
@@ -131,13 +125,13 @@ def main_loop(t2i,parser,log,infile):
         if elements[0]=='cd' and len(elements)>1:
             if os.path.exists(elements[1]):
                 print(f"setting image output directory to {elements[1]}")
-                t2i.outdir=elements[1]
+                opt.outdir=elements[1]
             else:
                 print(f"directory {elements[1]} does not exist")
             continue

         if elements[0]=='pwd':
-            print(f"current output directory is {t2i.outdir}")
+            print(f"current output directory is {opt.outdir}")
             continue

         if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
@@ -167,47 +161,19 @@ def main_loop(t2i,parser,log,infile):
             continue

         try:
-            if opt.init_img is None:
-                results = t2i.txt2img(**vars(opt))
-            else:
-                assert os.path.exists(opt.init_img),f"No file found at {opt.init_img}. On Linux systems, pressing <tab> after -I will autocomplete a list of possible image files."
-                if None not in (opt.width,opt.height):
-                    print('Warning: width and height options are ignored when modifying an init image')
-                results = t2i.img2img(**vars(opt))
+            file_writer  = PngWriter(opt)
+            opt.callback = file_writer(write_image)
+            run_generator(**vars(opt))
+            results      = file_writer.files_written
         except AssertionError as e:
             print(e)
             continue

-        allVariantResults = []
-        if opt.variants is not None:
-            print(f"Generating {opt.variants} variant(s)...")
-            newopt = copy.deepcopy(opt)
-            newopt.variants = None
-            for r in results:
-                newopt.init_img = r[0]
-                print(f"\t generating variant for {newopt.init_img}")
-                for j in range(0, opt.variants):
-                    try:
-                        variantResults = t2i.img2img(**vars(newopt))
-                        allVariantResults.append([newopt,variantResults])
-                    except AssertionError as e:
-                        print(e)
-                        continue
-            print(f"{opt.variants} Variants generated!")
-
         print("Outputs:")
         write_log_message(t2i,opt,results,log)

-        if allVariantResults:
-            print("Variant outputs:")
-            for vr in allVariantResults:
-                write_log_message(t2i,vr[0],vr[1],log)
-
     print("goodbye!")

 def write_log_message(t2i,opt,results,logfile):
     ''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
     switches = _reconstruct_switches(t2i,opt)
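write_log_message (unchanged here) records the prompt and seed in a 'Dream' text chunk of the PNG. With the new callback interface, that kind of metadata writing could also happen directly inside an image_callback; a hedged sketch under assumptions (the switch-string format, opt.prompt, and the closure wiring are illustrative, not part of this commit):

from PIL import PngImagePlugin

def make_metadata_callback(opt, writer):
    # returns a callback(image, seed) that saves the image with a 'Dream' text chunk
    def callback(image, seed):
        info = PngImagePlugin.PngInfo()
        info.add_text('Dream', f'"{opt.prompt}" -S{seed}')      # assumed switch format, mirroring write_log_message
        path = writer.unique_filename(opt, seed, None)
        image.save(path, 'PNG', pnginfo=info)
    return callback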
@@ -339,89 +305,7 @@ def create_cmd_parser():
     parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
     return parser

-if readline_available:
-    def setup_readline():
-        readline.set_completer(Completer(['cd','pwd',
-                                          '--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
-                                          '--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
-                                          '--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
-        readline.set_completer_delims(" ")
-        readline.parse_and_bind('tab: complete')
-        load_history()
-
-    def load_history():
-        histfile = os.path.join(os.path.expanduser('~'),".dream_history")
-        try:
-            readline.read_history_file(histfile)
-            readline.set_history_length(1000)
-        except FileNotFoundError:
-            pass
-        atexit.register(readline.write_history_file,histfile)
-
-    class Completer():
-        def __init__(self,options):
-            self.options = sorted(options)
-            return
-
-        def complete(self,text,state):
-            buffer = readline.get_line_buffer()
-
-            if text.startswith(('-I','--init_img')):
-                return self._path_completions(text,state,('.png'))
-
-            if buffer.strip().endswith('cd') or text.startswith(('.','/')):
-                return self._path_completions(text,state,())
-
-            response = None
-            if state == 0:
-                # This is the first time for this text, so build a match list.
-                if text:
-                    self.matches = [s
-                                    for s in self.options
-                                    if s and s.startswith(text)]
-                else:
-                    self.matches = self.options[:]
-
-            # Return the state'th item from the match list,
-            # if we have that many.
-            try:
-                response = self.matches[state]
-            except IndexError:
-                response = None
-            return response
-
-        def _path_completions(self,text,state,extensions):
-            # get the path so far
-            if text.startswith('-I'):
-                path = text.replace('-I','',1).lstrip()
-            elif text.startswith('--init_img='):
-                path = text.replace('--init_img=','',1).lstrip()
-            else:
-                path = text
-
-            matches = list()
-
-            path = os.path.expanduser(path)
-            if len(path)==0:
-                matches.append(text+'./')
-            else:
-                dir      = os.path.dirname(path)
-                dir_list = os.listdir(dir)
-                for n in dir_list:
-                    if n.startswith('.') and len(n)>1:
-                        continue
-                    full_path = os.path.join(dir,n)
-                    if full_path.startswith(path):
-                        if os.path.isdir(full_path):
-                            matches.append(os.path.join(os.path.dirname(text),n)+'/')
-                        elif n.endswith(extensions):
-                            matches.append(os.path.join(os.path.dirname(text),n))
-
-            try:
-                response = matches[state]
-            except IndexError:
-                response = None
-            return response
-
 if __name__ == "__main__":
     main()