mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 04:18:46 +08:00
Update working memory estimate for VAE decoding when tiling is being applied.
This commit is contained in:
parent
299eb94a05
commit
bd8017ecd5
@ -53,15 +53,30 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor) -> int:
|
||||
def _estimate_working_memory(
|
||||
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision). This estimate is accurate for both SD1 and SDXL.
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = 4 if self.fp32 else 2
|
||||
scaling_constant = 960 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
if use_tiling:
|
||||
tile_size = self.tile_size
|
||||
if tile_size == 0:
|
||||
tile_size = vae.tile_sample_min_size
|
||||
assert isinstance(tile_size, int)
|
||||
out_h = tile_size
|
||||
out_w = tile_size
|
||||
# We add 50% to the working memory estimate when tiling is enabled to account for factors like tile overlap
|
||||
# and number of tiles. We could make this more precise in the future, but this should be good enough for
|
||||
# most use cases.
|
||||
working_memory = int(out_h * out_w * element_size * scaling_constant * 1.5)
|
||||
else:
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
if self.fp32:
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
@ -73,11 +88,15 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
use_tiling = self.tiled or context.config.get().force_tiled_decode
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
|
||||
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=self._estimate_working_memory(latents)) as (_, vae),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
):
|
||||
context.util.signal_progress("Running VAE decoder")
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
@ -107,7 +126,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if self.tiled or context.config.get().force_tiled_decode:
|
||||
if use_tiling:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
Loading…
Reference in New Issue
Block a user