mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-04 22:43:40 +08:00
feat(app): use new signal_progress API for spandrel nodes
Both the vanilla and autoscale invocations report progress while processing each tile. The autoscale version, which may run the spandrel model multiple times, also includes the current iteration.
This commit is contained in:
parent
44c41e9549
commit
c0609f760f
@ -1,3 +1,4 @@
|
||||
import functools
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
@ -61,6 +62,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int,
|
||||
spandrel_model: SpandrelImageToImageModel,
|
||||
is_canceled: Callable[[], bool],
|
||||
step_callback: Callable[[int, int], None],
|
||||
) -> Image.Image:
|
||||
# Compute the image tiles.
|
||||
if tile_size > 0:
|
||||
@ -103,7 +105,12 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run the model on each tile.
|
||||
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
||||
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
|
||||
|
||||
# Update progress, starting with 0.
|
||||
step_callback(0, pbar.total)
|
||||
|
||||
for tile, scaled_tile in pbar:
|
||||
# Exit early if the invocation has been canceled.
|
||||
if is_canceled():
|
||||
raise CanceledException
|
||||
@ -136,6 +143,8 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
:,
|
||||
] = output_tile[top_overlap:, left_overlap:, :]
|
||||
|
||||
step_callback(pbar.n + 1, pbar.total)
|
||||
|
||||
# Convert the output tensor to a PIL image.
|
||||
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
||||
pil_image = Image.fromarray(np_image)
|
||||
@ -151,12 +160,20 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
def step_callback(step: int, total_steps: int) -> None:
|
||||
context.util.signal_progress(
|
||||
message=f"Processing tile {step}/{total_steps}",
|
||||
percentage=step / total_steps,
|
||||
)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Upscale the image
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
pil_image = self.upscale_image(
|
||||
image, self.tile_size, spandrel_model, context.util.is_canceled, step_callback
|
||||
)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@ -197,12 +214,27 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
target_width = int(image.width * self.scale)
|
||||
target_height = int(image.height * self.scale)
|
||||
|
||||
def step_callback(iteration: int, step: int, total_steps: int) -> None:
|
||||
context.util.signal_progress(
|
||||
message=self._get_progress_message(iteration, step, total_steps),
|
||||
percentage=step / total_steps,
|
||||
)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
iteration = 1
|
||||
context.util.signal_progress(self._get_progress_message(iteration))
|
||||
|
||||
# First pass of upscaling. Note: `pil_image` will be mutated.
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
pil_image = self.upscale_image(
|
||||
image,
|
||||
self.tile_size,
|
||||
spandrel_model,
|
||||
context.util.is_canceled,
|
||||
functools.partial(step_callback, iteration),
|
||||
)
|
||||
|
||||
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
|
||||
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
|
||||
@ -211,16 +243,22 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
|
||||
if is_upscale_model:
|
||||
# This is an upscale model, so we should keep upscaling until we reach the target size.
|
||||
iterations = 1
|
||||
while pil_image.width < target_width or pil_image.height < target_height:
|
||||
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
iterations += 1
|
||||
iteration += 1
|
||||
context.util.signal_progress(self._get_progress_message(iteration))
|
||||
pil_image = self.upscale_image(
|
||||
pil_image,
|
||||
self.tile_size,
|
||||
spandrel_model,
|
||||
context.util.is_canceled,
|
||||
functools.partial(step_callback, iteration),
|
||||
)
|
||||
|
||||
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
|
||||
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
|
||||
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
|
||||
# we should never reach this limit.
|
||||
if iterations >= 5:
|
||||
if iteration >= 5:
|
||||
context.logger.warning(
|
||||
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
|
||||
)
|
||||
@ -251,3 +289,10 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@classmethod
|
||||
def _get_progress_message(cls, iteration: int, step: int | None = None, total_steps: int | None = None) -> str:
|
||||
if step is not None and total_steps is not None:
|
||||
return f"Processing iteration {iteration}, tile {step}/{total_steps}"
|
||||
|
||||
return f"Processing iteration {iteration}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user