From 3b45a5dc1698140369a609fbb22bf7dce5ae9b42 Mon Sep 17 00:00:00 2001
From: lnyan
Date: Sun, 23 Oct 2022 19:20:49 +0800
Subject: [PATCH] Fix dtype issue for legacy pipeline

---
 app.py   | 54 ++++++++++++++++++++++++++++++++++++------------------
 utils.py |  4 ++--
 2 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/app.py b/app.py
index a12fcb3..280c641 100644
--- a/app.py
+++ b/app.py
@@ -80,6 +80,7 @@ finally:
         device = "cpu"
 
 
 import contextlib
+autocast = contextlib.nullcontext
 
 with open("config.yaml", "r") as yaml_in:
@@ -290,7 +291,7 @@ class StableDiffusionInpaint:
                 revision="fp16",
                 torch_dtype=torch.float16,
                 use_auth_token=token,
-                vae=vae
+                vae=vae,
             )
         else:
             inpaint = StableDiffusionInpaintPipeline.from_pretrained(
@@ -458,7 +459,7 @@ class StableDiffusion:
                 revision="fp16",
                 torch_dtype=torch.float16,
                 use_auth_token=token,
-                vae = vae
+                vae=vae,
             )
         else:
             text2img = StableDiffusionPipeline.from_pretrained(
@@ -482,11 +483,13 @@ class StableDiffusion:
                 revision="fp16",
                 torch_dtype=torch.float16,
                 use_auth_token=token,
-                vae=vae
+                vae=vae,
             ).to(device)
         else:
             inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-                "runwayml/stable-diffusion-inpainting", use_auth_token=token, vae=vae
+                "runwayml/stable-diffusion-inpainting",
+                use_auth_token=token,
+                vae=vae,
             ).to(device)
         text2img_unet.to(device)
         text2img = StableDiffusionPipeline(
@@ -657,10 +660,11 @@ class StableDiffusion:
             init_image = Image.fromarray(img)
             mask_image = Image.fromarray(mask)
             # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
-            if True:
-                input_image = init_image.resize(
-                    (process_width, process_height), resample=SAMPLING_MODE
-                )
+            input_image = init_image.resize(
+                (process_width, process_height), resample=SAMPLING_MODE
+            )
+            if self.inpainting_model:
+
                 images = inpaint_func(
                     prompt=prompt,
                     init_image=input_image,
@@ -670,6 +674,17 @@ class StableDiffusion:
                     mask_image=mask_image.resize((process_width, process_height)),
                     **extra_kwargs,
                 )["images"]
+            else:
+                with autocast("cuda"):
+                    images = inpaint_func(
+                        prompt=prompt,
+                        init_image=input_image,
+                        image=input_image,
+                        width=process_width,
+                        height=process_height,
+                        mask_image=mask_image.resize((process_width, process_height)),
+                        **extra_kwargs,
+                    )["images"]
         else:
             if True:
                 images = text2img(
@@ -740,20 +755,20 @@ def run_outpaint(
     pil = Image.open(io.BytesIO(data))
     if interrogate_mode:
         if "interrogator" not in model:
-            model["interrogator"]=Interrogator()
+            model["interrogator"] = Interrogator()
         interrogator = model["interrogator"]
-        img=np.array(pil)[:,:,0:3]
-        mask=np.array(pil)[:,:,-1]
+        img = np.array(pil)[:, :, 0:3]
+        mask = np.array(pil)[:, :, -1]
         x, y = np.nonzero(mask)
-        if len(x)>0:
+        if len(x) > 0:
             x0, x1 = x.min(), x.max() + 1
             y0, y1 = y.min(), y.max() + 1
-            img=img[x0:x1,y0:y1,:]
-            pil=Image.fromarray(img)
+            img = img[x0:x1, y0:y1, :]
+            pil = Image.fromarray(img)
         interrogate_ret = interrogator.interrogate(pil)
         return (
             gr.update(value=",".join([sel_buffer_str]),),
-            gr.update(label="Prompt",value=interrogate_ret),
+            gr.update(label="Prompt", value=interrogate_ret),
             state,
         )
     width, height = pil.size
@@ -995,7 +1010,10 @@ with blocks as demo:
         except Exception as e:
            print(e)
            return {token: gr.update(value=str(e))}
-        if model_choice in [ModelChoice.INPAINTING.value,ModelChoice.INPAINTING_IMG2IMG.value]:
+        if model_choice in [
+            ModelChoice.INPAINTING.value,
+            ModelChoice.INPAINTING_IMG2IMG.value,
+        ]:
            init_val = "cv2_ns"
         else:
            init_val = "patchmatch"
@@ -1009,7 +1027,7 @@ with blocks as demo:
             upload_button: gr.update(value="Upload Image"),
             model_selection: gr.update(visible=False),
             model_path_input: gr.update(visible=False),
-            init_mode: gr.update(value=init_val)
+            init_mode: gr.update(value=init_val),
         }
 
     setup_button.click(
@@ -1032,7 +1050,7 @@ with blocks as demo:
             upload_button,
             model_selection,
             model_path_input,
-            init_mode
+            init_mode,
         ],
         _js=setup_button_js,
     )
diff --git a/utils.py b/utils.py
index 8bfcfaa..dcaedff 100644
--- a/utils.py
+++ b/utils.py
@@ -93,8 +93,8 @@ def edge_pad(img, mask, mode=1):
 
 
 def perlin_noise(img, mask):
-    lin_x = np.linspace(0, 5, mask.shape[0], endpoint=False)
-    lin_y = np.linspace(0, 5, mask.shape[1], endpoint=False)
+    lin_x = np.linspace(0, 5, mask.shape[1], endpoint=False)
+    lin_y = np.linspace(0, 5, mask.shape[0], endpoint=False)
     x, y = np.meshgrid(lin_x, lin_y)
     avg = img.mean(axis=0).mean(axis=0)
     # noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]