Update working memory estimate for VAE decoding when tiling is being applied.

This commit is contained in:
Ryan Dick 2025-01-02 11:54:07 -05:00
parent 299eb94a05
commit bd8017ecd5

View File

@ -53,15 +53,30 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
def _estimate_working_memory(self, latents: torch.Tensor) -> int:
def _estimate_working_memory(
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = 4 if self.fp32 else 2
scaling_constant = 960 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
if use_tiling:
tile_size = self.tile_size
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
out_h = tile_size
out_w = tile_size
# We add 50% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = int(out_h * out_w * element_size * scaling_constant * 1.5)
else:
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
working_memory = out_h * out_w * element_size * scaling_constant
if self.fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
@ -73,11 +88,15 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
use_tiling = self.tiled or context.config.get().force_tiled_decode
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=self._estimate_working_memory(latents)) as (_, vae),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
):
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
@ -107,7 +126,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode:
if use_tiling:
vae.enable_tiling()
else:
vae.disable_tiling()