Remove unexpected artifacts in output images

This commit is contained in:
Brandon Rising 2024-12-13 23:53:22 -05:00 committed by Kent Keirsey
parent e0344a302c
commit 70811d0bd0

View File

@ -15,14 +15,14 @@ def prepare_control(
) -> torch.Tensor: ) -> torch.Tensor:
# load and encode the conditioning image # load and encode the conditioning image
img_cond = cond_image.convert("RGB") img_cond = cond_image.convert("RGB")
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS) img_cond = img_cond.resize((width, height), Image.Resampling.BICUBIC)
img_cond = np.array(img_cond) img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = rearrange(img_cond, "h w c -> 1 c h w") img_cond = rearrange(img_cond, "h w c -> 1 c h w")
ae_dtype = next(iter(ae.parameters())).dtype ae_dtype = next(iter(ae.parameters())).dtype
ae_device = next(iter(ae.parameters())).device ae_device = next(iter(ae.parameters())).device
img_cond = img_cond.to(device=ae_device, dtype=ae_dtype) img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
generator = torch.Generator(device=ae_device).manual_seed(seed) generator = torch.Generator(device=ae_device).manual_seed(seed)
img_cond = ae.encode(img_cond, sample=True, generator=generator) img_cond = ae.encode(img_cond, sample=False, generator=generator)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return img_cond return img_cond