Added --from_file argument to load input from a file. Closes #23

This commit is contained in:
Lincoln Stein 2022-08-23 00:30:06 -04:00
parent 6d1219deec
commit bc7b1fdd37

View File

@ -67,29 +67,45 @@ def main():
# gets rid of annoying messages about random seed
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
infile = None
try:
if opt.infile is not None:
infile = open(opt.infile,'r')
except FileNotFoundError as e:
print(e)
exit(-1)
# preload the model
if not debugging:
t2i.load_model()
print("\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)...")
log_path = os.path.join(opt.outdir,'..','dream_log.txt')
log_path = os.path.join(opt.outdir,'dream_log.txt')
with open(log_path,'a') as log:
cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log)
main_loop(t2i,cmd_parser,log,infile)
log.close()
infile.close()
def main_loop(t2i,parser,log):
def main_loop(t2i,parser,log,infile):
''' prompt/read/execute loop '''
done = False
while not done:
try:
command = input("dream> ")
command = infile.readline() if infile else input("dream> ")
except EOFError:
done = True
break
if infile and len(command)==0:
done = True
break
if command.startswith(('#','//')):
continue
try:
elements = shlex.split(command)
except ValueError as e:
@ -98,7 +114,7 @@ def main_loop(t2i,parser,log):
if len(elements)==0:
continue
if elements[0]=='q':
done = True
break
@ -232,6 +248,10 @@ def create_argv_parser():
dest='laion400m',
action='store_true',
help="fallback to the latent diffusion (laion400m) weights and config")
parser.add_argument("--from_file",
dest='infile',
type=str,
help="if specified, load prompts from this file")
parser.add_argument('-n','--iterations',
type=int,
default=1,