From 23c947ab0374220c39ac54fc00afcb74e809dd95 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 19 Jul 2023 20:23:30 +0300 Subject: [PATCH] automatically switch to 32-bit float VAE if the generated picture has NaNs. --- CHANGELOG.md | 3 ++- modules/processing.py | 41 ++++++++++++++++++++++++++++++++++++----- modules/shared.py | 1 + 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a561252c2..63a2c7d39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,7 +29,8 @@ * speedup extra networks listing * added `[none]` filename token. * removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs) - * add always_discard_next_to_last_sigma option to XYZ plot + * add always_discard_next_to_last_sigma option to XYZ plot + * automatically switch to 32-bit float VAE if the generated picture has NaNs without the need for `--no-half-vae` commandline flag. ### Extensions and API: * api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop diff --git a/modules/processing.py b/modules/processing.py index e028bf9ef..a74a53027 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -14,7 +14,7 @@ from skimage import exposure from typing import Any, Dict, List import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -538,6 +538,40 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): + samples = [] + + for i in range(batch.shape[0]): + sample = decode_first_stage(model, batch[i:i + 1])[0] + + if check_for_nans: + try: + devices.test_for_nans(sample, "vae") + except devices.NansException as e: + if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision: + raise e + + errors.print_error_explanation( + "A tensor with all NaNs was produced in VAE.\n" + "Web UI will now convert VAE into 32-bit float and retry.\n" + "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n" + "To always start with 32-bit VAE, use --no-half-vae commandline flag." + ) + + devices.dtype_vae = torch.float32 + model.first_stage_model.to(devices.dtype_vae) + batch = batch.to(devices.dtype_vae) + + sample = decode_first_stage(model, batch[i:i + 1])[0] + + if target_device is not None: + sample = sample.to(target_device) + + samples.append(sample) + + return samples + + def decode_first_stage(model, x): x = model.decode_first_stage(x.to(devices.dtype_vae)) @@ -758,10 +792,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) - x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] - for x in x_samples_ddim: - devices.test_for_nans(x, "vae") - + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/modules/shared.py b/modules/shared.py index 1ce7b49e5..aa72c9c87 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -427,6 +427,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), + "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), }))