mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
feat(app): process accepts custom invocation context builder
This commit is contained in:
parent
4c94d41fa9
commit
a1a3e60431
@ -2,7 +2,7 @@ import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.events.events_common import (
|
||||
@ -28,7 +28,11 @@ from invokeai.app.services.session_processor.session_processor_base import (
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException, SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError
|
||||
from invokeai.app.services.shared.graph import NodeInputError
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.services.shared.invocation_context import (
|
||||
InvocationContext,
|
||||
InvocationContextData,
|
||||
build_invocation_context,
|
||||
)
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
|
||||
@ -42,6 +46,9 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
|
||||
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
|
||||
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
|
||||
build_invocation_context: Callable[
|
||||
[InvocationServices, InvocationContextData, Callable[[], bool]], InvocationContext
|
||||
] = build_invocation_context,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -50,6 +57,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
on_after_run_node_callbacks: Callbacks to run after each node completes.
|
||||
on_node_error_callbacks: Callbacks to run when a node errors.
|
||||
on_after_run_session_callbacks: Callbacks to run after the session completes.
|
||||
build_invocation_context: A function that builds the invocation context. This is called for each invocation. A default implementation is provided.
|
||||
"""
|
||||
|
||||
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
|
||||
@ -57,6 +65,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
||||
self._on_node_error_callbacks = on_node_error_callbacks or []
|
||||
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
||||
self._build_invocation_context = build_invocation_context
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||
self._services = services
|
||||
@ -119,11 +128,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._services,
|
||||
is_canceled=self._is_canceled,
|
||||
)
|
||||
context = self._build_invocation_context(self._services, data, self._is_canceled)
|
||||
|
||||
# Invoke the node
|
||||
output = invocation.invoke_internal(context=context, services=self._services)
|
||||
|
Loading…
Reference in New Issue
Block a user