add clipseg support for creating inpaint masks from text

On the command line, the new option is --text_mask or -tm.
Example:

```
invoke> a baseball -I /path/to/still_life.png -tm orange
```

This will find the orange fruit in the still life painting and replace
it with an image of a baseball.
This commit is contained in:
Lincoln Stein 2022-10-16 23:30:24 -04:00
parent 32122e0312
commit 20551857da
9 changed files with 155 additions and 22 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 338 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

View File

@ -154,7 +154,7 @@ Here are the invoke> command that apply to txt2img:
| --seed <int> | -S<int> | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.|
| --sampler <sampler>| -A<sampler>| k_lms | Sampler to use. Use -h to get list of available samplers. |
| --hires_fix | | | Larger images often have duplication artefacts. This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution |
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
| --png_compression <0-9> | -z<0-9> | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
| --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt |
| --individual | -i | True | Turn off grid mode (deprecated; leave off --grid instead) |
| --outdir <path> | -o<path> | outputs/img_samples | Temporarily change the location of these images |
@ -212,11 +212,35 @@ accepts additional options:
[Inpainting](./INPAINTING.md) for details.
inpainting accepts all the arguments used for txt2img and img2img, as
well as the --mask (-M) argument:
well as the --mask (-M) and --text_mask (-tm) arguments:
| Argument <img width="100" align="right"/> | Shortcut | Default | Description |
|--------------------|------------|---------------------|--------------|
| `--init_mask <path>` | `-M<path>` | `None` |Path to an image the same size as the initial_image, with areas for inpainting made transparent.|
| `--text_mask <prompt> [<float>]` | `-tm <prompt> [<float>]` | <none> | Create a mask from a text prompt describing part of the image|
`--text_mask` (short form `-tm`) is a way to generate a mask using a
text description of the part of the image to replace. For example, if
you have an image of a breakfast plate with a bagel, toast and
scrambled eggs, you can selectively mask the bagel and replace it with
a piece of cake this way:
~~~
invoke> a piece of cake -I /path/to/breakfast.png -tm bagel
~~~
The algorithm uses <a
href="https://github.com/timojl/clipseg">clipseg</a> to classify
different regions of the image. The classifier puts out a confidence
score for each region it identifies. Generally regions that score
above 0.5 are reliable, but if you are getting too much or too little
masking you can adjust the threshold down (to get more mask), or up
(to get less). In this example, by passing `-tm` a higher value, we
are insisting on a more stringent classification.
~~~
invoke> a piece of cake -I /path/to/breakfast.png -tm bagel 0.6
~~~
# Other Commands

View File

@ -34,7 +34,46 @@ original unedited image and the masked (partially transparent) image:
invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png
```
We are hoping to get rid of the need for this workaround in an upcoming release.
## **Masking using Text**
You can also create a mask using a text prompt to select the part of
the image you want to alter, using the <a
href="https://github.com/timojl/clipseg">clipseg</a> algorithm. This
works on any image, not just ones generated by InvokeAI.
The `--text_mask` (short form `-tm`) option takes two arguments. The
first argument is a text description of the part of the image you wish
to mask (paint over). If the text description contains a space, you must
surround it with quotation marks. The optional second argument is the
minimum threshold for the mask classifier's confidence score, described
in more detail below.
To see how this works in practice, here's an image of a still life
painting that I got off the web.
<img src="../assets/still-life-scaled.jpg">
You can selectively mask out the
orange and replace it with a baseball in this way:
~~~
invoke> a baseball -I /path/to/still_life.png -tm orange
~~~
<img src="../assets/still-life-inpainted.png">
The clipseg classifier produces a confidence score for each region it
identifies. Generally regions that score above 0.5 are reliable, but
if you are getting too much or too little masking you can adjust the
threshold down (to get more mask), or up (to get less). In this
example, by passing `-tm` a higher value, we are insisting on a tigher
mask. However, if you make it too high, the orange may not be picked
up at all!
~~~
invoke> a baseball -I /path/to/breakfast.png -tm orange 0.6
~~~
### Inpainting is not changing the masked region enough!

View File

@ -34,7 +34,8 @@ from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.model_cache import ModelCache
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
@ -188,6 +189,7 @@ class Generate:
self.esrgan = esrgan
self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -266,6 +268,7 @@ class Generate:
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
text_mask = None,
fit = False,
strength = None,
init_color = None,
@ -298,6 +301,8 @@ class Generate:
seamless // whether the generated image should tile
hires_fix // whether the Hires Fix should be applied during generation
init_img // path to an initial image
init_mask // path to a mask for the initial image
text_mask // a text string that will be used to guide clipseg generation of the init_mask
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
@ -405,6 +410,7 @@ class Generate:
width,
height,
fit=fit,
text_mask=text_mask,
)
# TODO: Hacky selection of operation to perform. Needs to be refactored.
@ -620,17 +626,14 @@ class Generate:
width,
height,
fit=False,
text_mask=None,
):
init_image = None
init_mask = None
if not img:
return None, None
image = self._load_img(
img,
width,
height,
)
image = self._load_img(img)
if image.width < self.width and image.height < self.height:
print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
@ -648,10 +651,12 @@ class Generate:
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
if mask:
mask_image = self._load_img(
mask, width, height) # this returns an Image
mask_image = self._load_img(mask) # this returns an Image
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
elif text_mask:
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
return init_image, init_mask
def _make_base(self):
@ -830,7 +835,7 @@ class Generate:
print(msg)
def _load_img(self, img, width, height)->Image:
def _load_img(self, img)->Image:
if isinstance(img, Image.Image):
image = img
print(
@ -892,6 +897,29 @@ class Generate:
mask = ImageOps.invert(mask)
return mask
# TODO: The latter part of this method repeats code from _create_init_mask()
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
prompt = text_mask[0]
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
if self.txt2mask is None:
self.txt2mask = Txt2Mask(device = self.device)
segmented = self.txt2mask.segment(image, prompt)
mask = segmented.to_mask(float(confidence_level))
mask = mask.convert('RGB')
# now we adjust the size
if fit:
mask = self._fit_image(mask, (width, height))
else:
mask = self._squeeze_image(mask)
mask = mask.resize((mask.width//downsampling, mask.height //
downsampling), resample=Image.Resampling.NEAREST)
mask = np.array(mask)
mask = mask.astype(np.float32) / 255.0
mask = mask[None].transpose(0, 3, 1, 2)
mask = torch.from_numpy(mask)
return mask.to(self.device)
def _has_transparency(self, image):
if image.info.get("transparency", None) is not None:
return True

View File

@ -677,6 +677,14 @@ class Args(object):
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument(
'-tm',
'--text_mask',
nargs='+',
type=str,
help='Use the clipseg classifier to generate the mask area for inpainting. Provide a description of the area to mask ("a mug"), optionally followed by the confidence level threshold (0-1.0; defaults to 0.5).',
default=None,
)
img2img_group.add_argument(
'--init_color',
type=str,

View File

@ -74,3 +74,4 @@ class Txt2Img(Generator):
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x

View File

@ -54,6 +54,7 @@ COMMANDS = (
'--hires_fix',
'--inpaint_replace','-r',
'--png_compression','-z',
'--text_mask','-tm',
'!fix','!fetch','!history','!search','!clear',
'!models','!switch','!import_model','!edit_model'
)

View File

@ -36,6 +36,7 @@ from torchvision import transforms
CLIP_VERSION = 'ViT-B/16'
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
CLIPSEG_SIZE = 352
class SegmentedGrayscale(object):
def __init__(self, image:Image, heatmap:torch.Tensor):
@ -43,28 +44,39 @@ class SegmentedGrayscale(object):
self.image = image
def to_grayscale(self)->Image:
return Image.fromarray(np.uint8(self.heatmap*255))
return self._rescale(Image.fromarray(np.uint8(self.heatmap*255)))
def to_mask(self,threshold:float=0.5)->Image:
discrete_heatmap = self.heatmap.lt(threshold).int()
return Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
def to_transparent(self)->Image:
transparent_image = self.image.copy()
transparent_image.putalpha(self.to_image)
transparent_image.putalpha(self.to_grayscale())
return transparent_image
# unscales and uncrops the 352x352 heatmap so that it matches the image again
def _rescale(self, heatmap:Image)->Image:
size = self.image.width if (self.image.width > self.image.height) else self.image.height
resized_image = heatmap.resize(
(size,size),
resample=Image.Resampling.LANCZOS
)
return resized_image.crop((0,0,self.image.width,self.image.height))
class Txt2Mask(object):
'''
Create new Txt2Mask object. The optional device argument can be one of
'cuda', 'mps' or 'cpu'.
'''
def __init__(self,device='cpu'):
print('>> Initializing clipseg model')
print('>> Initializing clipseg model for text to mask inference')
self.device = device
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
self.model.eval()
self.model.to(device)
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device(device)), strict=False)
# initially we keep everything in cpu to conserve space
self.model.to('cpu')
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad()
def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
@ -73,18 +85,38 @@ class Txt2Mask(object):
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
'''
self._to_device(self.device)
prompts = [prompt] # right now we operate on just a single prompt at a time
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((image.width, image.height)), # must be multiple of 64...
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
])
img = transform(image).unsqueeze(0)
img = self._scale_and_crop(image)
img = transform(img).unsqueeze(0)
preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
heatmap = torch.sigmoid(preds[0][0]).cpu()
self._to_device('cpu')
return SegmentedGrayscale(image, heatmap)
def _to_device(self, device):
self.model.to(device)
def _scale_and_crop(self, image:Image)->Image:
scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
if image.width > image.height: # width is constraint
scale = CLIPSEG_SIZE / image.width
else:
scale = CLIPSEG_SIZE / image.height
scaled_image.paste(
image.resize(
(int(scale * image.width),
int(scale * image.height)
),
resample=Image.Resampling.LANCZOS
),box=(0,0)
)
return scaled_image