InvokeAI/invokeai/app/services/model_load/model_load_default.py
psychedelicious 725c03cf87 refactor(nodes): merge processors
Consolidate graph processing logic into session processor.

With graphs as the unit of work, and the session queue distributing graphs, we no longer need the invocation queue or processor.

Instead, the session processor dequeues the next session and processes it in a simple loop, greatly simplifying the app.

- Remove `graph_execution_manager` service.
- Remove `queue` (invocation queue) service.
- Remove `processor` (invocation processor) service.
- Remove queue-related logic from `Invoker`. It now only starts and stops the services, providing them with access to other services.
- Remove the `invocation_retrieval_error` and `session_retrieval_error` events; they are no longer used.
- Clean up stats service now that it is less coupled to the rest of the app.
- Refactor cancellation logic: cancellations now originate from the session queue (i.e. the HTTP cancel endpoint) and are emitted as events. The processor receives these events and sets the canceled event; access to that event is provided to the invocation context for e.g. the step callback (see the sketch after this commit message).
- Remove `sessions` router; it provided access to `graph_executions` but that no longer exists.
2024-03-01 10:42:33 +11:00
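
For orientation, here is a minimal sketch of the loop described above. The names (`SessionQueueItem`, `SimpleSessionProcessor`, `_run_session`) are hypothetical stand-ins, not the actual InvokeAI classes; only the dequeue/run/cancel-via-event shape reflects the commit description.

# Illustrative sketch only -- hypothetical names, not the InvokeAI session
# processor. Shape: dequeue the next session, run it in a simple loop, and
# cancel via an event set from outside (in InvokeAI, cancellation arrives as
# an event originating from the session queue's HTTP cancel endpoint).
import queue
import threading
from dataclasses import dataclass


@dataclass
class SessionQueueItem:
    """Hypothetical stand-in for a queued session/graph."""

    session_id: str


class SimpleSessionProcessor:
    """Dequeue sessions and process them in a plain loop."""

    def __init__(self) -> None:
        self._queue: "queue.Queue[SessionQueueItem]" = queue.Queue()
        self._cancel_event = threading.Event()
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def start(self) -> None:
        self._thread.start()

    def stop(self) -> None:
        self._stop_event.set()

    def enqueue(self, item: SessionQueueItem) -> None:
        self._queue.put(item)

    def cancel_current(self) -> None:
        # Called when a cancellation event is received.
        self._cancel_event.set()

    def _loop(self) -> None:
        while not self._stop_event.is_set():
            try:
                item = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue
            self._cancel_event.clear()
            self._run_session(item)

    def _run_session(self, item: SessionQueueItem) -> None:
        # Stand-in for executing the session's graph node by node; the cancel
        # event would be checked e.g. from the step callback.
        for _step in range(10):
            if self._cancel_event.is_set():
                return

The real processor additionally emits progress, error, and completion events and records stats; the sketch only captures the loop-plus-cancel-event structure.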


# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""

from typing import Optional, Type

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
    LoadedModel,
    ModelLoaderRegistry,
    ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger

from .model_load_base import ModelLoadServiceBase

class ModelLoadService(ModelLoadServiceBase):
    """Wrapper around ModelLoaderRegistry."""

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        ram_cache: ModelCacheBase[AnyModel],
        convert_cache: ModelConvertCacheBase,
        registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
    ):
        """Initialize the model load service."""
        logger = InvokeAILogger.get_logger(self.__class__.__name__)
        logger.setLevel(app_config.log_level.upper())
        self._logger = logger
        self._app_config = app_config
        self._ram_cache = ram_cache
        self._convert_cache = convert_cache
        self._registry = registry
        # The invoker is attached in start(); until then, no events are emitted.
        self._invoker: Optional[Invoker] = None

    def start(self, invoker: Invoker) -> None:
        """Record the invoker so that load events can be emitted."""
        self._invoker = invoker

    @property
    def ram_cache(self) -> ModelCacheBase[AnyModel]:
        """Return the RAM cache used by this loader."""
        return self._ram_cache

    @property
    def convert_cache(self) -> ModelConvertCacheBase:
        """Return the checkpoint convert cache used by this loader."""
        return self._convert_cache

    def load_model(
        self,
        model_config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        """
        Given a model's configuration, load it and return the LoadedModel object.

        :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
        :param submodel_type: For main (pipeline) models, the submodel to fetch.
        :param context_data: Invocation context data, used for event reporting.
        """
        if context_data:
            self._emit_load_event(
                context_data=context_data,
                model_config=model_config,
            )

        # Look up the loader implementation registered for this model type/format,
        # instantiate it with the shared caches, and let it do the actual load.
        implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type)  # type: ignore
        loaded_model: LoadedModel = implementation(
            app_config=self._app_config,
            logger=self._logger,
            ram_cache=self._ram_cache,
            convert_cache=self._convert_cache,
        ).load_model(model_config, submodel_type)

        if context_data:
            self._emit_load_event(
                context_data=context_data,
                model_config=model_config,
                loaded=True,
            )
        return loaded_model

    def _emit_load_event(
        self,
        context_data: InvocationContextData,
        model_config: AnyModelConfig,
        loaded: bool = False,
    ) -> None:
        # Events can only be emitted once the invoker has been attached via start().
        if not self._invoker:
            return

        if not loaded:
            self._invoker.services.events.emit_model_load_started(
                queue_id=context_data.queue_id,
                queue_item_id=context_data.queue_item_id,
                queue_batch_id=context_data.batch_id,
                graph_execution_state_id=context_data.session_id,
                model_config=model_config,
            )
        else:
            self._invoker.services.events.emit_model_load_completed(
                queue_id=context_data.queue_id,
                queue_item_id=context_data.queue_item_id,
                queue_batch_id=context_data.batch_id,
                graph_execution_state_id=context_data.session_id,
                model_config=model_config,
            )
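
A minimal usage sketch, not part of the file: it assumes the caller already holds an app config, the two caches, an invoker, and a model config record, and that LoadedModel can be used as a context manager that locks the model into memory (the common InvokeAI pattern); `example_load` is a hypothetical name.

# Illustrative usage sketch; the arguments are assumed to be supplied by the
# surrounding application -- this function is not part of the service itself.
def example_load(
    app_config: InvokeAIAppConfig,
    ram_cache: ModelCacheBase[AnyModel],
    convert_cache: ModelConvertCacheBase,
    invoker: Invoker,
    model_config: AnyModelConfig,
) -> None:
    service = ModelLoadService(
        app_config=app_config,
        ram_cache=ram_cache,
        convert_cache=convert_cache,
    )
    service.start(invoker)  # attach the invoker so load events can be emitted

    # For main (pipeline) models, a SubModelType member would be passed as
    # submodel_type; None loads the model as a whole.
    loaded = service.load_model(model_config, submodel_type=None)

    # Assumption: LoadedModel acts as a context manager that locks the model
    # into memory for the duration of the block and returns the raw model.
    with loaded as model:
        _ = model  # run inference here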