Add custom model choice for docker, add correction modes

This commit is contained in:
lnyan 2022-10-19 18:05:02 +08:00
parent c733a9e22a
commit 445ee72bf9
6 changed files with 70 additions and 30 deletions

71
app.py
View File

@ -408,7 +408,7 @@ if args.auth is not None:
args.auth = tuple(args.auth)
def get_model(token="", model_choice=""):
def get_model(token="", model_choice="", model_path=""):
if "model" not in model:
if not USE_GLID and model_choice == "glid-3-xl-stable":
model_choice = "stablediffusion"
@ -418,6 +418,9 @@ def get_model(token="", model_choice=""):
elif model_choice == "remote_model":
print(f"Using {args.remote_model}")
tmp = StableDiffusion(token=token, model_name=args.remote_model)
elif model_path:
print(f"Using {model_path}")
tmp = StableDiffusion(token=token, model_name=model_path)
elif model_choice == "stablediffusion":
tmp = StableDiffusion(token)
elif model_choice == "waifudiffusion":
@ -476,8 +479,7 @@ def run_outpaint(
)
base64_str_lst = []
for image in images:
if use_correction:
image = correction_func.run(pil.resize(image.size), image)
image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
resized_img = image.resize((width, height), resample=SAMPLING_MODE,)
out = sel_buffer.copy()
out[:, :, 0:3] = np.array(resized_img)
@ -530,6 +532,7 @@ min-height: 0rem;
}
""",
)
model_path_input_val = ""
with blocks as demo:
# title
title = gr.Markdown(
@ -544,10 +547,12 @@ with blocks as demo:
if not RUN_IN_SPACE:
model_choices_lst = ["stablediffusion", "waifudiffusion", "glid-3-xl-stable"]
if args.local_model:
model_path_input_val = args.local_model
model_choices_lst.insert(0, "local_model")
elif args.remote_model:
model_path_input_val = args.remote_model
model_choices_lst.insert(0, "remote_model")
with gr.Row():
with gr.Row(elem_id="setup_row"):
with gr.Column(scale=4, min_width=350):
token = gr.Textbox(
label="Huggingface token",
@ -581,11 +586,17 @@ with blocks as demo:
precision=0,
elem_id="selection_size",
)
model_path_input = gr.Textbox(
value=model_path_input_val,
label="Custom Model Path",
placeholder="Ignore this if you are not using Docker",
elem_id="model_path_input",
)
setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
with gr.Row():
with gr.Column(scale=3, min_width=270):
with gr.Column(scale=2, min_width=150):
init_mode = gr.Radio(
label="Init mode",
label="Init Mode",
choices=[
"patchmatch",
"edge_pad",
@ -597,8 +608,17 @@ with blocks as demo:
value="patchmatch",
type="value",
)
postprocess_check = gr.Radio(
label="Photometric Correction Mode",
choices=[
"disabled",
"mask_mode",
"border_mode",
],
value="disabled",
type="value",
)
# canvas control
sd_generate_num = gr.Number(label="Sample number", value=1, precision=0)
with gr.Column(scale=3, min_width=270):
sd_prompt = gr.Textbox(
@ -610,17 +630,26 @@ with blocks as demo:
lines=2,
)
with gr.Column(scale=2, min_width=150):
sd_strength = gr.Slider(
label="Strength", minimum=0.0, maximum=1.0, value=0.75, step=0.01
)
sd_scheduler = gr.Dropdown(
list(scheduler_dict.keys()), label="Scheduler", value="PLMS"
)
with gr.Column(scale=1, min_width=150):
with gr.Group():
with gr.Row():
sd_generate_num = gr.Number(
label="Sample number", value=1, precision=0
)
sd_strength = gr.Slider(
label="Strength",
minimum=0.0,
maximum=1.0,
value=0.75,
step=0.01,
)
with gr.Row():
sd_scheduler = gr.Dropdown(
list(scheduler_dict.keys()), label="Scheduler", value="PLMS"
)
sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
with gr.Column(scale=1, min_width=80):
sd_step = gr.Number(label="Step", value=50, precision=0)
sd_guidance = gr.Number(label="Guidance", value=7.5)
sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
xss_js = load_js("xss").replace("\n", " ")
@ -631,9 +660,6 @@ with blocks as demo:
)
# sd pipeline parameters
sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
postprocess_check = gr.Checkbox(
label="Photometric Correction", value=False, visible=False
)
sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
upload_button = gr.Button(
@ -648,9 +674,9 @@ with blocks as demo:
upload_output_state = gr.State(value=0)
if not RUN_IN_SPACE:
def setup_func(token_val, width, height, size, model_choice):
def setup_func(token_val, width, height, size, model_choice, model_path):
try:
get_model(token_val, model_choice)
get_model(token_val, model_choice, model_path=model_path)
except Exception as e:
print(e)
return {token: gr.update(value=str(e))}
@ -663,6 +689,7 @@ with blocks as demo:
frame: gr.update(visible=True),
upload_button: gr.update(value="Upload Image"),
model_selection: gr.update(visible=False),
model_path_input: gr.update(visible=False),
}
setup_button.click(
@ -673,6 +700,7 @@ with blocks as demo:
canvas_height,
selection_size,
model_selection,
model_path_input,
],
outputs=[
token,
@ -683,6 +711,7 @@ with blocks as demo:
frame,
upload_button,
model_selection,
model_path_input,
],
_js=setup_button_js,
)

View File

@ -283,7 +283,7 @@ dependencies:
- fpie==0.2.4
- google-auth==2.12.0
- google-auth-oauthlib==0.4.6
- gradio==3.4.0
- gradio==3.6
- grpcio==1.49.1
- h11==0.12.0
- httpcore==0.15.0

View File

@ -18,7 +18,8 @@ function(sel_buffer_str,
let app=document.querySelector("gradio-app");
app=app.shadowRoot??app;
sel_buffer=app.querySelector("#input textarea").value;
({resize_check,enable_safety,use_correction,enable_img2img,use_seed,seed_val}=window.config_obj);
let use_correction_bak=false;
({resize_check,enable_safety,use_correction_bak,enable_img2img,use_seed,seed_val}=window.config_obj);
return [
sel_buffer,
prompt_text,

View File

@ -1,8 +1,11 @@
function(token_val, width, height, size, model_choice){
function(token_val, width, height, size, model_choice, model_path){
let app=document.querySelector("gradio-app");
app=app.shadowRoot??app;
app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
app.querySelector("#setup_row").style.display="none";
app.querySelector("#model_path_input").style.display="none";
let frame=app.querySelector("#sdinfframe").contentWindow.document;
if(frame.querySelector("#setup").value=="0")
{
window.my_setup=setInterval(function(){
@ -21,5 +24,5 @@ function(token_val, width, height, size, model_choice){
{
frame.querySelector("#draw").click();
}
return [token_val, width, height, size, model_choice];
return [token_val, width, height, size, model_choice, model_path];
}

View File

@ -134,7 +134,7 @@ var toolbar=new w2toolbar({
{ type: "button", id: "setting", tooltip: "Settings", icon: "fa-solid fa-sliders" },
{ type: "break" },
check_button("enable_img2img","Enable Img2Img",false),
check_button("use_correction","Photometric Correction",false),
// check_button("use_correction","Photometric Correction",false),
check_button("resize_check","Resize Small Input",true),
check_button("enable_safety","Enable Safety Checker",true),
{type: "break"},

View File

@ -61,13 +61,18 @@ class PhotometricCorrection:
)
self.proc=proc
def run(self, original_image, inpainted_image):
def run(self, original_image, inpainted_image, mode="mask_mode"):
if mode=="disabled":
return inpainted_image
input_arr=np.array(original_image)
output_arr=np.array(inpainted_image)
mask=input_arr[:,:,-1]
mask=255-mask
mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
mask = mask.repeat(8, axis=0).repeat(8, axis=1)
if mode=="mask_mode":
mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
mask = mask.repeat(8, axis=0).repeat(8, axis=1)
else:
mask[8:-9,8:-9]=255
mask = mask[:,:,np.newaxis].repeat(3,axis=2)
nmask=mask.copy()
output_arr2=output_arr[:,:,0:3].copy()
@ -91,7 +96,7 @@ class PhotometricCorrection:
for i in range(0, args.n, args.p):
if proc.root:
result, err = proc.step(args.p) # type: ignore
print(f"PIE: Iter {i + args.p}, abs error {err}")
print(f"PIE: Iter {i + args.p}, abs_err {err}")
else:
proc.step(args.p)
@ -199,3 +204,5 @@ class PhotometricCorrection:
)
self.parser=parser
# if __name__ =="__main__":
# process=PhotometricCorrection()