mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 04:18:46 +08:00
Integrate New WebUI with dream.py
This commit is contained in:
parent
3491a1688b
commit
af8383c770
813
backend/invoke_ai_web_server.py
Normal file
813
backend/invoke_ai_web_server.py
Normal file
@ -0,0 +1,813 @@
|
||||
import eventlet
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from flask import Flask, redirect, send_from_directory
|
||||
from flask_socketio import SocketIO
|
||||
from PIL import Image
|
||||
from uuid import uuid4
|
||||
from threading import Event
|
||||
|
||||
from ldm.dream.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.dream.conditioning import split_weighted_subprompts
|
||||
|
||||
from backend.modules.parameters import parameters_to_command
|
||||
|
||||
# Loading Arguments
|
||||
opt = Args()
|
||||
args = opt.parse_args()
|
||||
|
||||
|
||||
class InvokeAIWebServer:
|
||||
def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
|
||||
self.host = args.host
|
||||
self.port = args.port
|
||||
|
||||
self.generate = generate
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
|
||||
self.canceled = Event()
|
||||
|
||||
def run(self):
|
||||
self.setup_app()
|
||||
self.setup_flask()
|
||||
|
||||
def setup_flask(self):
|
||||
# Socket IO
|
||||
logger = True if args.web_verbose else False
|
||||
engineio_logger = True if args.web_verbose else False
|
||||
max_http_buffer_size = 10000000
|
||||
|
||||
# CORS Allowed Setup
|
||||
cors_allowed_origins = ['http://127.0.0.1:5173', 'http://localhost:5173']
|
||||
additional_allowed_origins = (
|
||||
opt.cors if opt.cors else []
|
||||
) # additional CORS allowed origins
|
||||
if self.host == '127.0.0.1':
|
||||
cors_allowed_origins.extend(
|
||||
[
|
||||
f'http://{self.host}:{self.port}',
|
||||
f'http://localhost:{self.port}',
|
||||
]
|
||||
)
|
||||
cors_allowed_origins = (
|
||||
cors_allowed_origins + additional_allowed_origins
|
||||
)
|
||||
|
||||
self.app = Flask(
|
||||
__name__, static_url_path='', static_folder='../frontend/dist/'
|
||||
)
|
||||
|
||||
self.socketio = SocketIO(
|
||||
self.app,
|
||||
logger=logger,
|
||||
engineio_logger=engineio_logger,
|
||||
max_http_buffer_size=max_http_buffer_size,
|
||||
cors_allowed_origins=cors_allowed_origins,
|
||||
ping_interval=(50, 50),
|
||||
ping_timeout=60,
|
||||
)
|
||||
|
||||
# Outputs Route
|
||||
self.app.config['OUTPUTS_FOLDER'] = f'../{args.outdir}'
|
||||
|
||||
@self.app.route('/outputs/<path:filename>')
|
||||
def outputs(filename):
|
||||
return send_from_directory(
|
||||
self.app.config['OUTPUTS_FOLDER'], filename
|
||||
)
|
||||
|
||||
# Base Route
|
||||
@self.app.route('/')
|
||||
def serve():
|
||||
if args.web_develop:
|
||||
return redirect('http://127.0.0.1:5173')
|
||||
else:
|
||||
return send_from_directory(
|
||||
self.app.static_folder, 'index.html'
|
||||
)
|
||||
|
||||
self.load_socketio_listeners(self.socketio)
|
||||
|
||||
print('>> Started Invoke AI Web Server!')
|
||||
if self.host == '0.0.0.0':
|
||||
print(
|
||||
f"Point your browser at http://localhost:{self.port} or use the host's DNS name or IP address."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
'>> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.'
|
||||
)
|
||||
print(f'>> Point your browser at http://{self.host}:{self.port}')
|
||||
|
||||
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
||||
|
||||
def setup_app(self):
|
||||
# location for "finished" images
|
||||
self.result_path = args.outdir
|
||||
# temporary path for intermediates
|
||||
self.intermediate_path = os.path.join(
|
||||
self.result_path, 'intermediates/'
|
||||
)
|
||||
# path for user-uploaded init images and masks
|
||||
self.init_image_path = os.path.join(self.result_path, 'init-images/')
|
||||
self.mask_image_path = os.path.join(self.result_path, 'mask-images/')
|
||||
# txt log
|
||||
self.log_path = os.path.join(self.result_path, 'dream_log.txt')
|
||||
# make all output paths
|
||||
[
|
||||
os.makedirs(path, exist_ok=True)
|
||||
for path in [
|
||||
self.result_path,
|
||||
self.intermediate_path,
|
||||
self.init_image_path,
|
||||
self.mask_image_path,
|
||||
]
|
||||
]
|
||||
|
||||
def load_socketio_listeners(self, socketio):
|
||||
@socketio.on('requestSystemConfig')
|
||||
def handle_request_capabilities():
|
||||
print(f'>> System config requested')
|
||||
config = self.get_system_config()
|
||||
socketio.emit('systemConfig', config)
|
||||
|
||||
@socketio.on('requestImages')
|
||||
def handle_request_images(page=1, offset=0, last_mtime=None):
|
||||
chunk_size = 50
|
||||
|
||||
if last_mtime:
|
||||
print(f'>> Latest images requested')
|
||||
else:
|
||||
print(
|
||||
f'>> Page {page} of images requested (page size {chunk_size} offset {offset})'
|
||||
)
|
||||
|
||||
paths = glob.glob(os.path.join(self.result_path, '*.png'))
|
||||
sorted_paths = sorted(
|
||||
paths, key=lambda x: os.path.getmtime(x), reverse=True
|
||||
)
|
||||
|
||||
if last_mtime:
|
||||
image_paths = filter(
|
||||
lambda x: os.path.getmtime(x) > last_mtime, sorted_paths
|
||||
)
|
||||
else:
|
||||
|
||||
image_paths = sorted_paths[
|
||||
slice(
|
||||
chunk_size * (page - 1) + offset,
|
||||
chunk_size * page + offset,
|
||||
)
|
||||
]
|
||||
page = page + 1
|
||||
|
||||
image_array = []
|
||||
|
||||
for path in image_paths:
|
||||
metadata = retrieve_metadata(path)
|
||||
image_array.append(
|
||||
{
|
||||
'url': path,
|
||||
'mtime': os.path.getmtime(path),
|
||||
'metadata': metadata['sd-metadata'],
|
||||
}
|
||||
)
|
||||
|
||||
socketio.emit(
|
||||
'galleryImages',
|
||||
{
|
||||
'images': image_array,
|
||||
'nextPage': page,
|
||||
'offset': offset,
|
||||
'onlyNewImages': True if last_mtime else False,
|
||||
},
|
||||
)
|
||||
|
||||
@socketio.on('generateImage')
|
||||
def handle_generate_image_event(
|
||||
generation_parameters, esrgan_parameters, gfpgan_parameters
|
||||
):
|
||||
print(
|
||||
f'>> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}'
|
||||
)
|
||||
self.generate_images(
|
||||
generation_parameters, esrgan_parameters, gfpgan_parameters
|
||||
)
|
||||
|
||||
@socketio.on('runESRGAN')
|
||||
def handle_run_esrgan_event(original_image, esrgan_parameters):
|
||||
print(
|
||||
f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}'
|
||||
)
|
||||
progress = {
|
||||
'currentStep': 1,
|
||||
'totalSteps': 1,
|
||||
'currentIteration': 1,
|
||||
'totalIterations': 1,
|
||||
'currentStatus': 'Preparing',
|
||||
'isProcessing': True,
|
||||
'currentStatusHasSteps': False,
|
||||
}
|
||||
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = Image.open(original_image['url'])
|
||||
|
||||
seed = (
|
||||
original_image['metadata']['seed']
|
||||
if 'seed' in original_image['metadata']
|
||||
else 'unknown_seed'
|
||||
)
|
||||
|
||||
progress['currentStatus'] = 'Upscaling'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = self.esrgan.process(
|
||||
image=image,
|
||||
upsampler_scale=esrgan_parameters['upscale'][0],
|
||||
strength=esrgan_parameters['upscale'][1],
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
progress['currentStatus'] = 'Saving image'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
esrgan_parameters['seed'] = seed
|
||||
metadata = self.parameters_to_post_processed_image_metadata(
|
||||
parameters=esrgan_parameters,
|
||||
original_image_path=original_image['url'],
|
||||
type='esrgan',
|
||||
)
|
||||
command = parameters_to_command(esrgan_parameters)
|
||||
|
||||
path = self.save_image(
|
||||
image,
|
||||
command,
|
||||
metadata,
|
||||
self.result_path,
|
||||
postprocessing='esrgan',
|
||||
)
|
||||
|
||||
self.write_log_message(
|
||||
f'[Upscaled] "{original_image["url"]}" > "{path}": {command}'
|
||||
)
|
||||
|
||||
progress['currentStatus'] = 'Finished'
|
||||
progress['currentStep'] = 0
|
||||
progress['totalSteps'] = 0
|
||||
progress['currentIteration'] = 0
|
||||
progress['totalIterations'] = 0
|
||||
progress['isProcessing'] = False
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
socketio.emit(
|
||||
'esrganResult',
|
||||
{
|
||||
'url': os.path.relpath(path),
|
||||
'mtime': os.path.getmtime(path),
|
||||
'metadata': metadata,
|
||||
},
|
||||
)
|
||||
|
||||
@socketio.on('runGFPGAN')
|
||||
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
|
||||
print(
|
||||
f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}'
|
||||
)
|
||||
progress = {
|
||||
'currentStep': 1,
|
||||
'totalSteps': 1,
|
||||
'currentIteration': 1,
|
||||
'totalIterations': 1,
|
||||
'currentStatus': 'Preparing',
|
||||
'isProcessing': True,
|
||||
'currentStatusHasSteps': False,
|
||||
}
|
||||
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = Image.open(original_image['url'])
|
||||
|
||||
seed = (
|
||||
original_image['metadata']['seed']
|
||||
if 'seed' in original_image['metadata']
|
||||
else 'unknown_seed'
|
||||
)
|
||||
|
||||
progress['currentStatus'] = 'Fixing faces'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = self.gfpgan.process(
|
||||
image=image,
|
||||
strength=gfpgan_parameters['gfpgan_strength'],
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
progress['currentStatus'] = 'Saving image'
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
gfpgan_parameters['seed'] = seed
|
||||
metadata = self.parameters_to_post_processed_image_metadata(
|
||||
parameters=gfpgan_parameters,
|
||||
original_image_path=original_image['url'],
|
||||
type='gfpgan',
|
||||
)
|
||||
command = parameters_to_command(gfpgan_parameters)
|
||||
|
||||
path = self.save_image(
|
||||
image,
|
||||
command,
|
||||
metadata,
|
||||
self.result_path,
|
||||
postprocessing='gfpgan',
|
||||
)
|
||||
|
||||
self.write_log_message(
|
||||
f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}'
|
||||
)
|
||||
|
||||
progress['currentStatus'] = 'Finished'
|
||||
progress['currentStep'] = 0
|
||||
progress['totalSteps'] = 0
|
||||
progress['currentIteration'] = 0
|
||||
progress['totalIterations'] = 0
|
||||
progress['isProcessing'] = False
|
||||
socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
socketio.emit(
|
||||
'gfpganResult',
|
||||
{
|
||||
'url': os.path.relpath(path),
|
||||
'mtime': os.path.getmtime(path),
|
||||
'metadata': metadata,
|
||||
},
|
||||
)
|
||||
|
||||
@socketio.on('cancel')
|
||||
def handle_cancel():
|
||||
print(f'>> Cancel processing requested')
|
||||
self.canceled.set()
|
||||
socketio.emit('processingCanceled')
|
||||
|
||||
# TODO: I think this needs a safety mechanism.
|
||||
@socketio.on('deleteImage')
|
||||
def handle_delete_image(path, uuid):
|
||||
print(f'>> Delete requested "{path}"')
|
||||
from send2trash import send2trash
|
||||
|
||||
send2trash(path)
|
||||
socketio.emit('imageDeleted', {'url': path, 'uuid': uuid})
|
||||
|
||||
# TODO: I think this needs a safety mechanism.
|
||||
@socketio.on('uploadInitialImage')
|
||||
def handle_upload_initial_image(bytes, name):
|
||||
print(f'>> Init image upload requested "{name}"')
|
||||
uuid = uuid4().hex
|
||||
split = os.path.splitext(name)
|
||||
name = f'{split[0]}.{uuid}{split[1]}'
|
||||
file_path = os.path.join(self.init_image_path, name)
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
newFile = open(file_path, 'wb')
|
||||
newFile.write(bytes)
|
||||
socketio.emit(
|
||||
'initialImageUploaded', {'url': file_path, 'uuid': ''}
|
||||
)
|
||||
|
||||
# TODO: I think this needs a safety mechanism.
|
||||
@socketio.on('uploadMaskImage')
|
||||
def handle_upload_mask_image(bytes, name):
|
||||
print(f'>> Mask image upload requested "{name}"')
|
||||
uuid = uuid4().hex
|
||||
split = os.path.splitext(name)
|
||||
name = f'{split[0]}.{uuid}{split[1]}'
|
||||
file_path = os.path.join(self.mask_image_path, name)
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
newFile = open(file_path, 'wb')
|
||||
newFile.write(bytes)
|
||||
socketio.emit('maskImageUploaded', {'url': file_path, 'uuid': ''})
|
||||
|
||||
# App Functions
|
||||
def get_system_config(self):
|
||||
return {
|
||||
'model': 'stable diffusion',
|
||||
'model_id': args.model,
|
||||
'model_hash': self.generate.model_hash,
|
||||
'app_id': APP_ID,
|
||||
'app_version': APP_VERSION,
|
||||
}
|
||||
|
||||
def generate_images(
|
||||
self, generation_parameters, esrgan_parameters, gfpgan_parameters
|
||||
):
|
||||
self.canceled.clear()
|
||||
|
||||
step_index = 1
|
||||
prior_variations = (
|
||||
generation_parameters['with_variations']
|
||||
if 'with_variations' in generation_parameters
|
||||
else []
|
||||
)
|
||||
"""
|
||||
If a result image is used as an init image, and then deleted, we will want to be
|
||||
able to use it as an init image in the future. Need to copy it.
|
||||
|
||||
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
|
||||
make a unique filename for it and copy it there.
|
||||
"""
|
||||
if 'init_img' in generation_parameters:
|
||||
filename = os.path.basename(generation_parameters['init_img'])
|
||||
if not os.path.exists(
|
||||
os.path.join(self.init_image_path, filename)
|
||||
):
|
||||
unique_filename = self.make_unique_init_image_filename(
|
||||
filename
|
||||
)
|
||||
new_path = os.path.join(self.init_image_path, unique_filename)
|
||||
shutil.copy(generation_parameters['init_img'], new_path)
|
||||
generation_parameters['init_img'] = new_path
|
||||
if 'init_mask' in generation_parameters:
|
||||
filename = os.path.basename(generation_parameters['init_mask'])
|
||||
if not os.path.exists(
|
||||
os.path.join(self.mask_image_path, filename)
|
||||
):
|
||||
unique_filename = self.make_unique_init_image_filename(
|
||||
filename
|
||||
)
|
||||
new_path = os.path.join(
|
||||
self.init_image_path, unique_filename
|
||||
)
|
||||
shutil.copy(generation_parameters['init_img'], new_path)
|
||||
generation_parameters['init_mask'] = new_path
|
||||
|
||||
totalSteps = self.calculate_real_steps(
|
||||
steps=generation_parameters['steps'],
|
||||
strength=generation_parameters['strength']
|
||||
if 'strength' in generation_parameters
|
||||
else None,
|
||||
has_init_image='init_img' in generation_parameters,
|
||||
)
|
||||
|
||||
progress = {
|
||||
'currentStep': 1,
|
||||
'totalSteps': totalSteps,
|
||||
'currentIteration': 1,
|
||||
'totalIterations': generation_parameters['iterations'],
|
||||
'currentStatus': 'Preparing',
|
||||
'isProcessing': True,
|
||||
'currentStatusHasSteps': False,
|
||||
}
|
||||
|
||||
self.socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
def image_progress(sample, step):
|
||||
if self.canceled.is_set():
|
||||
raise CanceledException
|
||||
|
||||
nonlocal step_index
|
||||
nonlocal generation_parameters
|
||||
nonlocal progress
|
||||
|
||||
progress['currentStep'] = step + 1
|
||||
progress['currentStatus'] = 'Generating'
|
||||
progress['currentStatusHasSteps'] = True
|
||||
|
||||
if (
|
||||
generation_parameters['progress_images']
|
||||
and step % 5 == 0
|
||||
and step < generation_parameters['steps'] - 1
|
||||
):
|
||||
image = self.generate.sample_to_image(sample)
|
||||
metadata = self.parameters_to_generated_image_metadata(generation_parameters)
|
||||
command = parameters_to_command(generation_parameters)
|
||||
path = self.save_image(image, command, metadata, self.intermediate_path, step_index=step_index, postprocessing=False)
|
||||
|
||||
step_index += 1
|
||||
self.socketio.emit(
|
||||
'intermediateResult',
|
||||
{
|
||||
'url': os.path.relpath(path),
|
||||
'mtime': os.path.getmtime(path),
|
||||
'metadata': metadata,
|
||||
},
|
||||
)
|
||||
self.socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
def image_done(image, seed, first_seed):
|
||||
nonlocal generation_parameters
|
||||
nonlocal esrgan_parameters
|
||||
nonlocal gfpgan_parameters
|
||||
nonlocal progress
|
||||
|
||||
step_index = 1
|
||||
nonlocal prior_variations
|
||||
|
||||
progress['currentStatus'] = 'Generation complete'
|
||||
self.socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
all_parameters = generation_parameters
|
||||
postprocessing = False
|
||||
|
||||
if (
|
||||
'variation_amount' in all_parameters
|
||||
and all_parameters['variation_amount'] > 0
|
||||
):
|
||||
first_seed = first_seed or seed
|
||||
this_variation = [[seed, all_parameters['variation_amount']]]
|
||||
all_parameters['with_variations'] = (
|
||||
prior_variations + this_variation
|
||||
)
|
||||
all_parameters['seed'] = first_seed
|
||||
elif 'with_variations' in all_parameters:
|
||||
all_parameters['seed'] = first_seed
|
||||
else:
|
||||
all_parameters['seed'] = seed
|
||||
|
||||
if esrgan_parameters:
|
||||
progress['currentStatus'] = 'Upscaling'
|
||||
progress['currentStatusHasSteps'] = False
|
||||
self.socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = self.esrgan.process(
|
||||
image=image,
|
||||
upsampler_scale=esrgan_parameters['level'],
|
||||
strength=esrgan_parameters['strength'],
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
postprocessing = True
|
||||
all_parameters['upscale'] = [
|
||||
esrgan_parameters['level'],
|
||||
esrgan_parameters['strength'],
|
||||
]
|
||||
|
||||
if gfpgan_parameters:
|
||||
progress['currentStatus'] = 'Fixing faces'
|
||||
progress['currentStatusHasSteps'] = False
|
||||
self.socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = self.gfpgan.process(
|
||||
image=image,
|
||||
strength=gfpgan_parameters['strength'],
|
||||
seed=seed,
|
||||
)
|
||||
postprocessing = True
|
||||
all_parameters['gfpgan_strength'] = gfpgan_parameters[
|
||||
'strength'
|
||||
]
|
||||
|
||||
progress['currentStatus'] = 'Saving image'
|
||||
self.socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
metadata = self.parameters_to_generated_image_metadata(
|
||||
all_parameters
|
||||
)
|
||||
command = parameters_to_command(all_parameters)
|
||||
|
||||
path = self.save_image(
|
||||
image,
|
||||
command,
|
||||
metadata,
|
||||
self.result_path,
|
||||
postprocessing=postprocessing,
|
||||
)
|
||||
|
||||
print(f'>> Image generated: "{path}"')
|
||||
self.write_log_message(f'[Generated] "{path}": {command}')
|
||||
|
||||
if progress['totalIterations'] > progress['currentIteration']:
|
||||
progress['currentStep'] = 1
|
||||
progress['currentIteration'] += 1
|
||||
progress['currentStatus'] = 'Iteration finished'
|
||||
progress['currentStatusHasSteps'] = False
|
||||
else:
|
||||
progress['currentStep'] = 0
|
||||
progress['totalSteps'] = 0
|
||||
progress['currentIteration'] = 0
|
||||
progress['totalIterations'] = 0
|
||||
progress['currentStatus'] = 'Finished'
|
||||
progress['isProcessing'] = False
|
||||
|
||||
self.socketio.emit('progressUpdate', progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
self.socketio.emit(
|
||||
'generationResult',
|
||||
{
|
||||
'url': os.path.relpath(path),
|
||||
'mtime': os.path.getmtime(path),
|
||||
'metadata': metadata,
|
||||
},
|
||||
)
|
||||
eventlet.sleep(0)
|
||||
|
||||
try:
|
||||
self.generate.prompt2image(
|
||||
**generation_parameters,
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done,
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except CanceledException:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.socketio.emit('error', {'message': (str(e))})
|
||||
print('\n')
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print('\n')
|
||||
|
||||
def parameters_to_generated_image_metadata(self, parameters):
|
||||
# top-level metadata minus `image` or `images`
|
||||
metadata = self.get_system_config()
|
||||
# remove any image keys not mentioned in RFC #266
|
||||
rfc266_img_fields = [
|
||||
'type',
|
||||
'postprocessing',
|
||||
'sampler',
|
||||
'prompt',
|
||||
'seed',
|
||||
'variations',
|
||||
'steps',
|
||||
'cfg_scale',
|
||||
'step_number',
|
||||
'width',
|
||||
'height',
|
||||
'extra',
|
||||
'seamless',
|
||||
]
|
||||
|
||||
rfc_dict = {}
|
||||
|
||||
for item in parameters.items():
|
||||
key, value = item
|
||||
if key in rfc266_img_fields:
|
||||
rfc_dict[key] = value
|
||||
|
||||
postprocessing = []
|
||||
|
||||
# 'postprocessing' is either null or an
|
||||
if 'gfpgan_strength' in parameters:
|
||||
|
||||
postprocessing.append(
|
||||
{
|
||||
'type': 'gfpgan',
|
||||
'strength': float(parameters['gfpgan_strength']),
|
||||
}
|
||||
)
|
||||
|
||||
if 'upscale' in parameters:
|
||||
postprocessing.append(
|
||||
{
|
||||
'type': 'esrgan',
|
||||
'scale': int(parameters['upscale'][0]),
|
||||
'strength': float(parameters['upscale'][1]),
|
||||
}
|
||||
)
|
||||
|
||||
rfc_dict['postprocessing'] = (
|
||||
postprocessing if len(postprocessing) > 0 else None
|
||||
)
|
||||
|
||||
# semantic drift
|
||||
rfc_dict['sampler'] = parameters['sampler_name']
|
||||
|
||||
# display weighted subprompts (liable to change)
|
||||
subprompts = split_weighted_subprompts(parameters['prompt'])
|
||||
subprompts = [{'prompt': x[0], 'weight': x[1]} for x in subprompts]
|
||||
rfc_dict['prompt'] = subprompts
|
||||
|
||||
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
|
||||
variations = []
|
||||
|
||||
if 'with_variations' in parameters:
|
||||
variations = [
|
||||
{'seed': x[0], 'weight': x[1]}
|
||||
for x in parameters['with_variations']
|
||||
]
|
||||
|
||||
rfc_dict['variations'] = variations
|
||||
|
||||
if 'init_img' in parameters:
|
||||
rfc_dict['type'] = 'img2img'
|
||||
rfc_dict['strength'] = parameters['strength']
|
||||
rfc_dict['fit'] = parameters['fit'] # TODO: Noncompliant
|
||||
rfc_dict['orig_hash'] = calculate_init_img_hash(
|
||||
parameters['init_img']
|
||||
)
|
||||
rfc_dict['init_image_path'] = parameters[
|
||||
'init_img'
|
||||
] # TODO: Noncompliant
|
||||
rfc_dict[
|
||||
'sampler'
|
||||
] = 'ddim' # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
|
||||
if 'init_mask' in parameters:
|
||||
rfc_dict['mask_hash'] = calculate_init_img_hash(
|
||||
parameters['init_mask']
|
||||
) # TODO: Noncompliant
|
||||
rfc_dict['mask_image_path'] = parameters[
|
||||
'init_mask'
|
||||
] # TODO: Noncompliant
|
||||
else:
|
||||
rfc_dict['type'] = 'txt2img'
|
||||
|
||||
metadata['image'] = rfc_dict
|
||||
|
||||
return metadata
|
||||
|
||||
def parameters_to_post_processed_image_metadata(
|
||||
self, parameters, original_image_path, type
|
||||
):
|
||||
# top-level metadata minus `image` or `images`
|
||||
metadata = self.get_system_config()
|
||||
|
||||
orig_hash = calculate_init_img_hash(original_image_path)
|
||||
|
||||
image = {'orig_path': original_image_path, 'orig_hash': orig_hash}
|
||||
|
||||
if type == 'esrgan':
|
||||
image['type'] = 'esrgan'
|
||||
image['scale'] = parameters['upscale'][0]
|
||||
image['strength'] = parameters['upscale'][1]
|
||||
elif type == 'gfpgan':
|
||||
image['type'] = 'gfpgan'
|
||||
image['strength'] = parameters['gfpgan_strength']
|
||||
else:
|
||||
raise TypeError(f'Invalid type: {type}')
|
||||
|
||||
metadata['image'] = image
|
||||
return metadata
|
||||
|
||||
def save_image(
|
||||
self,
|
||||
image,
|
||||
command,
|
||||
metadata,
|
||||
output_dir,
|
||||
step_index=None,
|
||||
postprocessing=False,
|
||||
):
|
||||
pngwriter = PngWriter(output_dir)
|
||||
prefix = pngwriter.unique_prefix()
|
||||
|
||||
seed = 'unknown_seed'
|
||||
|
||||
if 'image' in metadata:
|
||||
if 'seed' in metadata['image']:
|
||||
seed = metadata['image']['seed']
|
||||
|
||||
filename = f'{prefix}.{seed}'
|
||||
|
||||
if step_index:
|
||||
filename += f'.{step_index}'
|
||||
if postprocessing:
|
||||
filename += f'.postprocessed'
|
||||
|
||||
filename += '.png'
|
||||
|
||||
path = pngwriter.save_image_and_prompt_to_png(
|
||||
image=image, dream_prompt=command, metadata=metadata, name=filename
|
||||
)
|
||||
|
||||
return path
|
||||
|
||||
def make_unique_init_image_filename(self, name):
|
||||
uuid = uuid4().hex
|
||||
split = os.path.splitext(name)
|
||||
name = f'{split[0]}.{uuid}{split[1]}'
|
||||
return name
|
||||
|
||||
def calculate_real_steps(self, steps, strength, has_init_image):
|
||||
import math
|
||||
return math.floor(strength * steps) if has_init_image else steps
|
||||
|
||||
def write_log_message(self, message):
|
||||
"""Logs the filename and parameters used to generate or process that image to log file"""
|
||||
message = f'{message}\n'
|
||||
with open(self.log_path, 'a', encoding='utf-8') as file:
|
||||
file.writelines(message)
|
||||
|
||||
|
||||
class CanceledException(Exception):
|
||||
pass
|
@ -1,4 +1,4 @@
|
||||
from modules.parse_seed_weights import parse_seed_weights
|
||||
from backend.modules.parse_seed_weights import parse_seed_weights
|
||||
import argparse
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
|
@ -421,6 +421,23 @@ class Args(object):
|
||||
action='store_true',
|
||||
help='Start in web server mode.',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web_develop',
|
||||
dest='web_develop',
|
||||
action='store_true',
|
||||
help='Start in web server development mode.',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--web_verbose",
|
||||
action="store_true",
|
||||
help="Enables verbose logging",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--cors",
|
||||
nargs="*",
|
||||
type=str,
|
||||
help="Additional allowed origins, comma-separated",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--host',
|
||||
type=str,
|
||||
|
@ -12,11 +12,12 @@ sys.path.append('.') # corrects a weird problem on Macs
|
||||
import ldm.dream.readline
|
||||
from ldm.dream.args import Args, metadata_dumps, metadata_from_png
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from ldm.dream.server import DreamServer, ThreadingDreamServer
|
||||
from ldm.dream.image_util import make_grid
|
||||
from ldm.dream.log import write_log
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from backend.invoke_ai_web_server import InvokeAIWebServer
|
||||
|
||||
# Placeholder to be replaced with proper class that tracks the
|
||||
# outputs and associates with the prompt that generated them.
|
||||
# Just want to get the formatting look right for now.
|
||||
@ -111,16 +112,16 @@ def main():
|
||||
#set additional option
|
||||
gen.free_gpu_mem = opt.free_gpu_mem
|
||||
|
||||
# web server loops forever
|
||||
if opt.web:
|
||||
invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan)
|
||||
sys.exit(0)
|
||||
|
||||
if not infile:
|
||||
print(
|
||||
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
|
||||
)
|
||||
|
||||
# web server loops forever
|
||||
if opt.web:
|
||||
dream_server_loop(gen, opt.host, opt.port, opt.outdir, gfpgan)
|
||||
sys.exit(0)
|
||||
|
||||
main_loop(gen, opt, infile)
|
||||
|
||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||
@ -414,35 +415,20 @@ def get_next_command(infile=None) -> str: # command string
|
||||
print(f'#{command}')
|
||||
return command
|
||||
|
||||
def dream_server_loop(gen, host, port, outdir, gfpgan):
|
||||
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
|
||||
print('\n* --web was specified, starting web server...')
|
||||
# Change working directory to the stable-diffusion directory
|
||||
os.chdir(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
)
|
||||
|
||||
# Start server
|
||||
DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for
|
||||
DreamServer.outdir = outdir
|
||||
DreamServer.gfpgan_model_exists = False
|
||||
if gfpgan is not None:
|
||||
DreamServer.gfpgan_model_exists = gfpgan.gfpgan_model_exists
|
||||
|
||||
dream_server = ThreadingDreamServer((host, port))
|
||||
print(">> Started Stable Diffusion dream server!")
|
||||
if host == '0.0.0.0':
|
||||
print(
|
||||
f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
|
||||
else:
|
||||
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
|
||||
print(f">> Point your browser at http://{host}:{port}")
|
||||
|
||||
invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan)
|
||||
|
||||
try:
|
||||
dream_server.serve_forever()
|
||||
invoke_ai_web_server.run()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
dream_server.server_close()
|
||||
|
||||
|
||||
def write_log_message(results, log_path):
|
||||
"""logs the name of the output image, prompt, and prompt args to the terminal and log file"""
|
||||
|
Loading…
Reference in New Issue
Block a user