Added seamless tiling mode and the --seamless command-line options

This commit is contained in:
prixt 2022-09-03 15:13:31 +09:00
parent 33936430d0
commit d922b53c26
2 changed files with 35 additions and 0 deletions

View File

@ -14,6 +14,7 @@ from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torch import nn
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
@ -109,6 +110,7 @@ class T2I:
downsampling_factor
precision
strength
seamless
embedding_path
The vast majority of these arguments default to reasonable values.
@ -132,6 +134,7 @@ class T2I:
precision='autocast',
full_precision=False,
strength=0.75, # default in scripts/img2img.py
seamless=False,
embedding_path=None,
device_type = 'cuda',
# just to keep track of this parameter when regenerating prompt
@ -153,6 +156,7 @@ class T2I:
self.precision = precision
self.full_precision = full_precision
self.strength = strength
self.seamless = seamless
self.embedding_path = embedding_path
self.device_type = device_type
self.model = None # empty for now
@ -217,6 +221,7 @@ class T2I:
step_callback = None,
width = None,
height = None,
seamless = False,
# these are specific to img2img
init_img = None,
fit = False,
@ -238,6 +243,7 @@ class T2I:
width // width of image, in multiples of 64 (512)
height // height of image, in multiples of 64 (512)
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
seamless // whether the generated image should tile
init_img // path to an initial image - its dimensions override width and height
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
@ -265,6 +271,7 @@ class T2I:
seed = seed or self.seed
width = width or self.width
height = height or self.height
seamless = seamless or self.seamless
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
iterations = iterations or self.iterations
@ -274,6 +281,10 @@ class T2I:
model = (
self.load_model()
) # will instantiate the model or return it from cache
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert (
0.0 <= strength <= 1.0
@ -562,6 +573,10 @@ class T2I:
self._set_sampler()
for m in self.model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m._orig_padding_mode = m.padding_mode
return self.model
def _set_sampler(self):

View File

@ -9,6 +9,7 @@ import sys
import copy
import warnings
import time
import torch.nn as nn
from ldm.dream.devices import choose_torch_device
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
@ -60,6 +61,7 @@ def main():
grid = opt.grid,
# this is solely for recreating the prompt
latent_diffusion_weights=opt.laion400m,
seamless=opt.seamless,
embedding_path=opt.embedding_path,
device_type=opt.device
)
@ -92,6 +94,14 @@ def main():
f'>> model loaded in', '%4.2fs' % (time.time() - tic)
)
for m in t2i.model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m._orig_padding_mode = m.padding_mode
if opt.seamless:
m.padding_mode = 'circular'
if opt.seamless:
print(">> changed to seamless tiling mode")
if not infile:
print(
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
@ -374,6 +384,11 @@ def create_argv_parser():
default='outputs/img-samples',
help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
)
parser.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
parser.add_argument(
'--embedding_path',
type=str,
@ -474,6 +489,11 @@ def create_cmd_parser():
default=None,
help='Directory to save generated images and a log of prompts and seeds',
)
parser.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
parser.add_argument(
'-i',
'--individual',