feat(app): process accepts custom invocation context builder

This commit is contained in:
psychedelicious 2024-12-11 08:51:54 +10:00
parent 4c94d41fa9
commit a1a3e60431

View File

@ -2,7 +2,7 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from typing import Callable, Optional
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
@ -28,7 +28,11 @@ from invokeai.app.services.session_processor.session_processor_base import (
from invokeai.app.services.session_processor.session_processor_common import CanceledException, SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.services.shared.invocation_context import (
InvocationContext,
InvocationContextData,
build_invocation_context,
)
from invokeai.app.util.profiler import Profiler
@ -42,6 +46,9 @@ class DefaultSessionRunner(SessionRunnerBase):
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
build_invocation_context: Callable[
[InvocationServices, InvocationContextData, Callable[[], bool]], InvocationContext
] = build_invocation_context,
):
"""
Args:
@ -50,6 +57,7 @@ class DefaultSessionRunner(SessionRunnerBase):
on_after_run_node_callbacks: Callbacks to run after each node completes.
on_node_error_callbacks: Callbacks to run when a node errors.
on_after_run_session_callbacks: Callbacks to run after the session completes.
build_invocation_context: A function that builds the invocation context. This is called for each invocation. A default implementation is provided.
"""
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
@ -57,6 +65,7 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
self._build_invocation_context = build_invocation_context
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
self._services = services
@ -119,11 +128,7 @@ class DefaultSessionRunner(SessionRunnerBase):
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
context = build_invocation_context(
data=data,
services=self._services,
is_canceled=self._is_canceled,
)
context = self._build_invocation_context(self._services, data, self._is_canceled)
# Invoke the node
output = invocation.invoke_internal(context=context, services=self._services)