Consolidate graph processing logic into session processor.

With graphs as the unit of work, and the session queue distributing graphs, we no longer need the invocation queue or processor. Instead, the session processor dequeues the next session and processes it in a simple loop, greatly simplifying the app.

- Remove `graph_execution_manager` service.
- Remove `queue` (invocation queue) service.
- Remove `processor` (invocation processor) service.
- Remove queue-related logic from `Invoker`. It now only starts and stops the services, providing them with access to other services.
- Remove the unused `invocation_retrieval_error` and `session_retrieval_error` events; they are no longer needed.
- Clean up the stats service now that it is less coupled to the rest of the app.
- Refactor cancellation logic: cancellations now originate from the session queue (i.e. the HTTP cancel endpoint) and are emitted as events. The processor receives these events and sets the canceled event, which is exposed to the invocation context for e.g. the step callback.
- Remove the `sessions` router; it provided access to `graph_executions`, which no longer exists.
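A minimal sketch of the loop described above, to make the new shape concrete. This is not InvokeAI's actual SessionProcessor: the queue, event, and session interfaces (dequeue_next, on_session_canceled, session.run) are hypothetical stand-ins, and only the overall flow (dequeue a session, run it, let cancel events set a shared canceled flag) reflects the description above.

# Hypothetical sketch only; the queue/event/session interfaces are invented for
# illustration and are not InvokeAI's real API.
import threading
import time


class SessionProcessorSketch:
    def __init__(self, queue, events):
        self._queue = queue                  # session queue: distributes graphs as units of work
        self._cancel_event = threading.Event()
        self._stop_event = threading.Event()
        self._current_session_id = None
        # Cancellations originate from the session queue (HTTP cancel endpoint)
        # and arrive here as events; the processor only sets the canceled flag.
        events.on_session_canceled(self._handle_cancel)  # hypothetical subscription

    def _handle_cancel(self, session_id: str) -> None:
        if session_id == self._current_session_id:
            self._cancel_event.set()

    def run(self) -> None:
        # Dequeue the next session and process it in a simple loop.
        while not self._stop_event.is_set():
            session = self._queue.dequeue_next()  # hypothetical
            if session is None:
                time.sleep(1.0)  # nothing queued; poll again
                continue
            self._current_session_id = session.id
            self._cancel_event.clear()
            # The canceled event is exposed to the invocation context so that,
            # e.g., a step callback can abort a long-running invocation early.
            session.run(cancel_event=self._cancel_event)  # hypothetical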
114 lines · 4.1 KiB · Python
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""

from typing import Optional, Type

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
    LoadedModel,
    ModelLoaderRegistry,
    ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger

from .model_load_base import ModelLoadServiceBase

class ModelLoadService(ModelLoadServiceBase):
    """Wrapper around ModelLoaderRegistry."""

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        ram_cache: ModelCacheBase[AnyModel],
        convert_cache: ModelConvertCacheBase,
        registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
    ):
        """Initialize the model load service."""
        logger = InvokeAILogger.get_logger(self.__class__.__name__)
        logger.setLevel(app_config.log_level.upper())
        self._logger = logger
        self._app_config = app_config
        self._ram_cache = ram_cache
        self._convert_cache = convert_cache
        self._registry = registry
        # Set in start(); guards event emission when no invoker is available yet.
        self._invoker: Optional[Invoker] = None

    def start(self, invoker: Invoker) -> None:
        # The invoker gives this service access to the others, e.g. the event service.
        self._invoker = invoker

    @property
    def ram_cache(self) -> ModelCacheBase[AnyModel]:
        """Return the RAM cache used by this loader."""
        return self._ram_cache

    @property
    def convert_cache(self) -> ModelConvertCacheBase:
        """Return the checkpoint convert cache used by this loader."""
        return self._convert_cache

    def load_model(
        self,
        model_config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        """
        Given a model's configuration, load it and return the LoadedModel object.

        :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
        :param submodel_type: For main (pipeline) models, the submodel to fetch.
        :param context_data: Invocation context data, used for event reporting.
        """
        if context_data:
            self._emit_load_event(
                context_data=context_data,
                model_config=model_config,
            )

        implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type)  # type: ignore
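        # The registry has selected the loader class registered for this model's
        # base/type/format and may have adjusted the config and submodel type.
        # Instantiate it with the shared caches and perform the actual load.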
        loaded_model: LoadedModel = implementation(
            app_config=self._app_config,
            logger=self._logger,
            ram_cache=self._ram_cache,
            convert_cache=self._convert_cache,
        ).load_model(model_config, submodel_type)

        if context_data:
            self._emit_load_event(
                context_data=context_data,
                model_config=model_config,
                loaded=True,
            )
        return loaded_model

    def _emit_load_event(
        self,
        context_data: InvocationContextData,
        model_config: AnyModelConfig,
        loaded: Optional[bool] = False,
    ) -> None:
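        """Emit a model_load_started or model_load_completed event, if the invoker is available."""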
        if not self._invoker:
            return

        if not loaded:
            self._invoker.services.events.emit_model_load_started(
                queue_id=context_data.queue_id,
                queue_item_id=context_data.queue_item_id,
                queue_batch_id=context_data.batch_id,
                graph_execution_state_id=context_data.session_id,
                model_config=model_config,
            )
        else:
            self._invoker.services.events.emit_model_load_completed(
                queue_id=context_data.queue_id,
                queue_item_id=context_data.queue_item_id,
                queue_batch_id=context_data.batch_id,
                graph_execution_state_id=context_data.session_id,
                model_config=model_config,
            )
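
A hedged usage sketch of the service defined above. The cache, config, and model-record objects are assumed to be supplied by InvokeAI's startup and model-manager code (they are not constructed here), ModelLoadService is assumed to be in scope from the file above, and the helper name load_unet and the choice of submodel are illustrative only.

# Hypothetical usage sketch. `app_config`, `ram_cache`, `convert_cache`, and
# `model_config` are assumed to be supplied by InvokeAI's startup and
# model-record code; they are not constructed here.
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase


def load_unet(
    app_config: InvokeAIAppConfig,
    ram_cache: ModelCacheBase[AnyModel],
    convert_cache: ModelConvertCacheBase,
    model_config: AnyModelConfig,
) -> LoadedModel:
    service = ModelLoadService(
        app_config=app_config,
        ram_cache=ram_cache,
        convert_cache=convert_cache,
        # registry defaults to ModelLoaderRegistry
    )
    # Without context_data there is nothing to report, so no load events are
    # emitted and no Invoker needs to have been passed to start().
    return service.load_model(
        model_config,
        submodel_type=SubModelType.UNet,  # assumes a main (pipeline) model
    )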