add support for safety checker (NSFW filter)

Now you can activate the Hugging Face `diffusers` library safety check
for NSFW and other potentially disturbing imagery.

To turn on the safety check, pass --safety_checker at the command
line. For developers, the flag is `safety_checker=True` passed to
ldm.generate.Generate(). Once the safety checker is turned on, it
cannot be turned off unless you reinitialize a new Generate object.

When the safety checker is active, suspect images will be blurred and
a warning icon is added. There is also a warning message printed in
the CLI, but it can be a little hard to see because of its positioning
in the output stream.

There is a slight but noticeable delay when the safety checker runs.

Note that invisible watermarking is *not* currently implemented. The
watermark code distributed by the CompVis distribution uses a library
that does not seem to be able to retrieve the watermarks it creates,
and it does not appear that Hugging Face `diffusers` or other SD
distributions are doing any watermarking.
This commit is contained in:
Lincoln Stein 2022-10-23 22:26:18 -04:00
parent b7ce5b4f1b
commit b159b2fe42
10 changed files with 195 additions and 94 deletions

View File

@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
| `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
| `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
| `--safety_checker` | | `False` | Activate safety checker for NSFW and other potentially disturbing imagery |
| `--web` | | `False` | Start in web server mode |
| `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
| `--port <port>` | | `9090` | Which port web server should listen for requests on. |
@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
| `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
| `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
| `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
| `--device <device>` | `-d<device>` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" |
| `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
| `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |

View File

@ -19,6 +19,7 @@ dependencies:
# ```
- albumentations==1.2.1
- coloredlogs==15.0.1
- diffusers==0.6.0
- einops==0.4.1
- grpcio==1.46.4
- humanfriendly==10.0

View File

@ -26,6 +26,7 @@ dependencies:
- pyreadline3
- torch-fidelity==0.3.0
- transformers==4.21.3
- diffusers==0.6.0
- torchmetrics==0.7.0
- flask==2.1.3
- flask_socketio==5.3.0

View File

@ -132,20 +132,21 @@ class Generate:
def __init__(
self,
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
gfpgan=None,
codeformer=None,
esrgan=None,
free_gpu_mem=False,
safety_checker:bool=False,
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
):
mconfig = OmegaConf.load(conf)
self.height = None
@ -176,6 +177,7 @@ class Generate:
self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -203,6 +205,19 @@ class Generate:
# gets rid of annoying messages about random seed
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
# load safety checker if requested
if safety_checker:
try:
print('>> Initializing safety checker')
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker"
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
except Exception:
print('** An error was encountered while installing the safety checker:')
print(traceback.format_exc())
def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
@ -418,6 +433,11 @@ class Generate:
self.seed, variation_amount, with_variations
)
checker = {
'checker':self.safety_checker,
'extractor':self.safety_feature_extractor
} if self.safety_checker else None
results = generator.generate(
prompt,
iterations=iterations,
@ -428,10 +448,10 @@ class Generate:
conditioning=(uc, c),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_image=init_image, # notice that init_image is different from init_img
mask_image=mask_image,
strength=strength,
@ -440,7 +460,8 @@ class Generate:
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius
mask_blur_radius=mask_blur_radius,
safety_checker=checker
)
if init_color:

View File

@ -418,6 +418,11 @@ class Args(object):
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
)
model_group.add_argument(
'--safety_checker',
action='store_true',
help='Check for and blur potentially NSFW images',
)
file_group.add_argument(
'--from_file',
dest='infile',

View File

@ -7,25 +7,27 @@ import numpy as np
import random
import os
from tqdm import tqdm, trange
from PIL import Image
from PIL import Image, ImageFilter
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.invoke.devices import choose_autocast
from ldm.util import rand_perlin_2d
downsampling = 8
CAUTION_IMG = 'assets/caution.png'
class Generator():
def __init__(self, model, precision):
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
self.safety_checker = None
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):
@ -42,8 +44,10 @@ class Generator():
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
make_image = self.get_make_image(
prompt,
init_image = init_image,
@ -77,10 +81,17 @@ class Generator():
pass
image = make_image(x_T)
if self.safety_checker is not None:
image = self.safety_check(image)
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed, first_seed=first_seed)
seed = self.new_seed()
return results
def sample_to_image(self,samples):
@ -169,6 +180,39 @@ class Generator():
return v2
def safety_check(self,image:Image.Image):
    '''
    Run the configured safety checker on a single image.

    `self.safety_checker` is expected to be a dict holding the model
    under 'checker' and the feature extractor under 'extractor'.
    Returns the result of self.blur(image) when the checker flags the
    image as NSFW, otherwise returns `image` itself.
    '''
    import diffusers

    nsfw_model = self.safety_checker['checker']
    feature_extractor = self.safety_checker['extractor']
    clip_features = feature_extractor([image], return_tensors="pt")

    # the checker also wants the image as a numpy batch, so convert it
    # back from PIL: scale to 0..1, add a batch axis, then reorder axes
    pixels = np.array(image).astype(np.float32) / 255.0
    batch = pixels[None].transpose(0, 3, 1, 2)

    diffusers.logging.set_verbosity_error()
    # only the per-image flags are used; the checker's returned images are ignored
    _, nsfw_flags = nsfw_model(images=batch, clip_input=clip_features.pixel_values)

    if not nsfw_flags[0]:
        return image
    print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
    return self.blur(image)
def blur(self,input):
    '''
    Return a Gaussian-blurred (radius 32) version of `input` with the
    caution icon from CAUTION_IMG, scaled to half size, pasted at the
    top-left corner. If the icon file cannot be found, the blurred
    image is returned without the icon.
    '''
    blurred = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
    try:
        icon = Image.open(CAUTION_IMG)
        half_size = (icon.width // 2, icon.height // 2)
        icon = icon.resize(half_size)
        # the icon is passed a second time as its own paste mask
        blurred.paste(icon, (0, 0), icon)
    except FileNotFoundError:
        pass
    return blurred
# this is a handy routine for debugging use. Given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):

View File

@ -1,5 +1,6 @@
albumentations==0.4.3
einops==0.3.0
diffusers==0.6.0
huggingface-hub==0.8.1
imageio==2.9.0
imageio-ffmpeg==0.4.2

View File

@ -32,6 +32,7 @@ send2trash
dependency_injector==4.40.0
eventlet
realesrgan
diffusers
git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan

View File

@ -69,16 +69,17 @@ def main():
# creating a Generate object:
try:
gen = Generate(
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
full_precision = opt.full_precision,
precision = opt.precision,
precision = opt.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
safety_checker=opt.safety_checker,
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')

View File

@ -5,7 +5,7 @@
# two machines must share a common .cache directory.
from transformers import CLIPTokenizer, CLIPTextModel
import clip
from transformers import BertTokenizerFast
from transformers import BertTokenizerFast, AutoFeatureExtractor
import sys
import transformers
import os
@ -17,41 +17,39 @@ import traceback
transformers.logging.set_verbosity_error()
#---------------------------------------------
# this will preload the Bert tokenizer files
print('Loading bert tokenizer (ignore deprecation errors)...', end='')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
sys.stdout.flush()
def download_bert():
    """Preload the 'bert-base-uncased' fast tokenizer, silencing deprecation warnings."""
    print('Installing bert tokenizer (ignore deprecation errors)...', end='')
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', DeprecationWarning)
        # the tokenizer object itself is not needed; the call is made for its side effect
        BertTokenizerFast.from_pretrained('bert-base-uncased')
    print('...success')
    sys.stdout.flush()
#---------------------------------------------
# this will download requirements for Kornia
print('Loading Kornia requirements...', end='')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
import kornia
print('...success')
def download_kornia():
    """Import kornia once, with deprecation warnings silenced, so its requirements are set up."""
    print('Installing Kornia requirements...', end='')
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', DeprecationWarning)
        # imported purely for the side effects of the import itself
        import kornia
    print('...success')
version = 'openai/clip-vit-large-patch14'
sys.stdout.flush()
print('Loading CLIP model...',end='')
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('...success')
#---------------------------------------------
def download_clip():
    """Preload the CLIP tokenizer and text model for 'openai/clip-vit-large-patch14'."""
    clip_repo = 'openai/clip-vit-large-patch14'
    sys.stdout.flush()
    print('Loading CLIP model...',end='')
    # the loaded objects are discarded; only the from_pretrained calls matter here
    for loader in (CLIPTokenizer, CLIPTextModel):
        loader.from_pretrained(clip_repo)
    print('...success')
# In the event that the user has installed GFPGAN and also elected to use
# RealESRGAN, this will attempt to download the model needed by RealESRGANer
gfpgan = False
try:
from realesrgan import RealESRGANer
gfpgan = True
except ModuleNotFoundError:
pass
if gfpgan:
print('Loading models from RealESRGAN and facexlib...',end='')
#---------------------------------------------
def download_gfpgan():
print('Installing models from RealESRGAN and facexlib...',end='')
try:
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@ -94,44 +92,72 @@ if gfpgan:
print('Error loading GFPGAN:')
print(traceback.format_exc())
print('preloading CodeFormer model file...',end='')
try:
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
if not os.path.exists(model_dest):
print('Downloading codeformer model file...')
#---------------------------------------------
def download_codeformer():
    """
    Make sure the CodeFormer weights file exists locally.

    If 'ldm/invoke/restoration/codeformer/weights/codeformer.pth' is not
    present, its parent directory is created and the file is fetched
    from the CodeFormer v0.1.0 GitHub release. Any failure is reported
    with a traceback rather than raised, and in that case the
    '...success' message is NOT printed.
    """
    print('Installing CodeFormer model file...',end='')
    model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
    model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
    try:
        if not os.path.exists(model_dest):
            print('Downloading codeformer model file...')
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
            urllib.request.urlretrieve(model_url,model_dest)
    except Exception:
        # best-effort step: report the problem and let the caller carry on,
        # but do not fall through to the success message below
        print('Error loading CodeFormer:')
        print(traceback.format_exc())
        return
    print('...success')
#---------------------------------------------
def download_clipseg():
print('Installing clipseg model for text-based masking...',end='')
try:
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
model_dest = 'src/clipseg/clipseg_weights.zip'
weights_dir = 'src/clipseg/weights'
if not os.path.exists(weights_dir):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
urllib.request.urlretrieve(model_url,model_dest)
except Exception:
print('Error loading CodeFormer:')
print(traceback.format_exc())
print('...success')
with zipfile.ZipFile(model_dest,'r') as zip:
zip.extractall('src/clipseg')
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
os.remove(model_dest)
from clipseg_models.clipseg import CLIPDensePredT
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
model.eval()
model.load_state_dict(
torch.load(
'src/clipseg/weights/rd64-uni-refined.pth',
map_location=torch.device('cpu')
),
strict=False,
)
except Exception:
print('Error installing clipseg model:')
print(traceback.format_exc())
print('...success')
print('Loading clipseg model for text-based masking...',end='')
try:
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
model_dest = 'src/clipseg/clipseg_weights.zip'
weights_dir = 'src/clipseg/weights'
if not os.path.exists(weights_dir):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
urllib.request.urlretrieve(model_url,model_dest)
with zipfile.ZipFile(model_dest,'r') as zip:
zip.extractall('src/clipseg')
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
os.remove(model_dest)
from clipseg_models.clipseg import CLIPDensePredT
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
model.eval()
model.load_state_dict(
torch.load(
'src/clipseg/weights/rd64-uni-refined.pth',
map_location=torch.device('cpu')
),
strict=False,
)
except Exception:
print('Error installing clipseg model:')
print(traceback.format_exc())
print('...success')
#-------------------------------------
def download_safety_checker():
    """
    Preload the "CompVis/stable-diffusion-safety-checker" feature
    extractor and safety checker model.

    Both a missing `diffusers` package and a failure while loading
    either model are reported with a traceback instead of being raised,
    matching the other download_* helpers in this script. The
    '...success' message is only printed when both loads complete.
    """
    print('Installing safety model for NSFW content detection...',end='')
    safety_model_id = "CompVis/stable-diffusion-safety-checker"
    try:
        # imported here because diffusers may not be installed
        from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
        # the loaded objects are not needed; the calls are made for their side effect
        AutoFeatureExtractor.from_pretrained(safety_model_id)
        StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
    except Exception:
        print('Error installing safety checker model:')
        print(traceback.format_exc())
        return
    print('...success')
#-------------------------------------
if __name__ == '__main__':
    # run every preload step, in this fixed order
    preload_steps = (
        download_bert,
        download_kornia,
        download_clip,
        download_gfpgan,
        download_codeformer,
        download_clipseg,
        download_safety_checker,
    )
    for step in preload_steps:
        step()