mirror of
https://github.com/lkwq007/stablediffusion-infinity.git
synced 2025-01-08 11:57:27 +08:00
Fix dtype issue for legacy pipeline
This commit is contained in:
parent
535c95fe50
commit
3b45a5dc16
54
app.py
54
app.py
@ -80,6 +80,7 @@ finally:
|
|||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
autocast = contextlib.nullcontext
|
autocast = contextlib.nullcontext
|
||||||
|
|
||||||
with open("config.yaml", "r") as yaml_in:
|
with open("config.yaml", "r") as yaml_in:
|
||||||
@ -290,7 +291,7 @@ class StableDiffusionInpaint:
|
|||||||
revision="fp16",
|
revision="fp16",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
use_auth_token=token,
|
use_auth_token=token,
|
||||||
vae=vae
|
vae=vae,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inpaint = StableDiffusionInpaintPipeline.from_pretrained(
|
inpaint = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
@ -458,7 +459,7 @@ class StableDiffusion:
|
|||||||
revision="fp16",
|
revision="fp16",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
use_auth_token=token,
|
use_auth_token=token,
|
||||||
vae = vae
|
vae=vae,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
text2img = StableDiffusionPipeline.from_pretrained(
|
text2img = StableDiffusionPipeline.from_pretrained(
|
||||||
@ -482,11 +483,13 @@ class StableDiffusion:
|
|||||||
revision="fp16",
|
revision="fp16",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
use_auth_token=token,
|
use_auth_token=token,
|
||||||
vae=vae
|
vae=vae,
|
||||||
).to(device)
|
).to(device)
|
||||||
else:
|
else:
|
||||||
inpaint = StableDiffusionInpaintPipeline.from_pretrained(
|
inpaint = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
"runwayml/stable-diffusion-inpainting", use_auth_token=token, vae=vae
|
"runwayml/stable-diffusion-inpainting",
|
||||||
|
use_auth_token=token,
|
||||||
|
vae=vae,
|
||||||
).to(device)
|
).to(device)
|
||||||
text2img_unet.to(device)
|
text2img_unet.to(device)
|
||||||
text2img = StableDiffusionPipeline(
|
text2img = StableDiffusionPipeline(
|
||||||
@ -657,10 +660,11 @@ class StableDiffusion:
|
|||||||
init_image = Image.fromarray(img)
|
init_image = Image.fromarray(img)
|
||||||
mask_image = Image.fromarray(mask)
|
mask_image = Image.fromarray(mask)
|
||||||
# mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
|
# mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
|
||||||
if True:
|
input_image = init_image.resize(
|
||||||
input_image = init_image.resize(
|
(process_width, process_height), resample=SAMPLING_MODE
|
||||||
(process_width, process_height), resample=SAMPLING_MODE
|
)
|
||||||
)
|
if self.inpainting_model:
|
||||||
|
|
||||||
images = inpaint_func(
|
images = inpaint_func(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
init_image=input_image,
|
init_image=input_image,
|
||||||
@ -670,6 +674,17 @@ class StableDiffusion:
|
|||||||
mask_image=mask_image.resize((process_width, process_height)),
|
mask_image=mask_image.resize((process_width, process_height)),
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)["images"]
|
)["images"]
|
||||||
|
else:
|
||||||
|
with autocast("cuda"):
|
||||||
|
images = inpaint_func(
|
||||||
|
prompt=prompt,
|
||||||
|
init_image=input_image,
|
||||||
|
image=input_image,
|
||||||
|
width=process_width,
|
||||||
|
height=process_height,
|
||||||
|
mask_image=mask_image.resize((process_width, process_height)),
|
||||||
|
**extra_kwargs,
|
||||||
|
)["images"]
|
||||||
else:
|
else:
|
||||||
if True:
|
if True:
|
||||||
images = text2img(
|
images = text2img(
|
||||||
@ -740,20 +755,20 @@ def run_outpaint(
|
|||||||
pil = Image.open(io.BytesIO(data))
|
pil = Image.open(io.BytesIO(data))
|
||||||
if interrogate_mode:
|
if interrogate_mode:
|
||||||
if "interrogator" not in model:
|
if "interrogator" not in model:
|
||||||
model["interrogator"]=Interrogator()
|
model["interrogator"] = Interrogator()
|
||||||
interrogator = model["interrogator"]
|
interrogator = model["interrogator"]
|
||||||
img=np.array(pil)[:,:,0:3]
|
img = np.array(pil)[:, :, 0:3]
|
||||||
mask=np.array(pil)[:,:,-1]
|
mask = np.array(pil)[:, :, -1]
|
||||||
x, y = np.nonzero(mask)
|
x, y = np.nonzero(mask)
|
||||||
if len(x)>0:
|
if len(x) > 0:
|
||||||
x0, x1 = x.min(), x.max() + 1
|
x0, x1 = x.min(), x.max() + 1
|
||||||
y0, y1 = y.min(), y.max() + 1
|
y0, y1 = y.min(), y.max() + 1
|
||||||
img=img[x0:x1,y0:y1,:]
|
img = img[x0:x1, y0:y1, :]
|
||||||
pil=Image.fromarray(img)
|
pil = Image.fromarray(img)
|
||||||
interrogate_ret = interrogator.interrogate(pil)
|
interrogate_ret = interrogator.interrogate(pil)
|
||||||
return (
|
return (
|
||||||
gr.update(value=",".join([sel_buffer_str]),),
|
gr.update(value=",".join([sel_buffer_str]),),
|
||||||
gr.update(label="Prompt",value=interrogate_ret),
|
gr.update(label="Prompt", value=interrogate_ret),
|
||||||
state,
|
state,
|
||||||
)
|
)
|
||||||
width, height = pil.size
|
width, height = pil.size
|
||||||
@ -995,7 +1010,10 @@ with blocks as demo:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
return {token: gr.update(value=str(e))}
|
return {token: gr.update(value=str(e))}
|
||||||
if model_choice in [ModelChoice.INPAINTING.value,ModelChoice.INPAINTING_IMG2IMG.value]:
|
if model_choice in [
|
||||||
|
ModelChoice.INPAINTING.value,
|
||||||
|
ModelChoice.INPAINTING_IMG2IMG.value,
|
||||||
|
]:
|
||||||
init_val = "cv2_ns"
|
init_val = "cv2_ns"
|
||||||
else:
|
else:
|
||||||
init_val = "patchmatch"
|
init_val = "patchmatch"
|
||||||
@ -1009,7 +1027,7 @@ with blocks as demo:
|
|||||||
upload_button: gr.update(value="Upload Image"),
|
upload_button: gr.update(value="Upload Image"),
|
||||||
model_selection: gr.update(visible=False),
|
model_selection: gr.update(visible=False),
|
||||||
model_path_input: gr.update(visible=False),
|
model_path_input: gr.update(visible=False),
|
||||||
init_mode: gr.update(value=init_val)
|
init_mode: gr.update(value=init_val),
|
||||||
}
|
}
|
||||||
|
|
||||||
setup_button.click(
|
setup_button.click(
|
||||||
@ -1032,7 +1050,7 @@ with blocks as demo:
|
|||||||
upload_button,
|
upload_button,
|
||||||
model_selection,
|
model_selection,
|
||||||
model_path_input,
|
model_path_input,
|
||||||
init_mode
|
init_mode,
|
||||||
],
|
],
|
||||||
_js=setup_button_js,
|
_js=setup_button_js,
|
||||||
)
|
)
|
||||||
|
4
utils.py
4
utils.py
@ -93,8 +93,8 @@ def edge_pad(img, mask, mode=1):
|
|||||||
|
|
||||||
|
|
||||||
def perlin_noise(img, mask):
|
def perlin_noise(img, mask):
|
||||||
lin_x = np.linspace(0, 5, mask.shape[0], endpoint=False)
|
lin_x = np.linspace(0, 5, mask.shape[1], endpoint=False)
|
||||||
lin_y = np.linspace(0, 5, mask.shape[1], endpoint=False)
|
lin_y = np.linspace(0, 5, mask.shape[0], endpoint=False)
|
||||||
x, y = np.meshgrid(lin_x, lin_y)
|
x, y = np.meshgrid(lin_x, lin_y)
|
||||||
avg = img.mean(axis=0).mean(axis=0)
|
avg = img.mean(axis=0).mean(axis=0)
|
||||||
# noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]
|
# noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]
|
||||||
|
Loading…
Reference in New Issue
Block a user