# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""

from typing import Optional, Type

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
    LoadedModel,
    ModelLoaderRegistry,
    ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger

from .model_load_base import ModelLoadServiceBase

class ModelLoadService(ModelLoadServiceBase):
    """Wrapper around ModelLoaderRegistry."""

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        ram_cache: ModelCacheBase[AnyModel],
        convert_cache: ModelConvertCacheBase,
        registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
    ):
        """Initialize the model load service."""
        logger = InvokeAILogger.get_logger(self.__class__.__name__)
        logger.setLevel(app_config.log_level.upper())
        self._logger = logger
        self._app_config = app_config
        self._ram_cache = ram_cache
        self._convert_cache = convert_cache
        self._registry = registry

    def start(self, invoker: Invoker) -> None:
        self._invoker = invoker

    @property
    def ram_cache(self) -> ModelCacheBase[AnyModel]:
        """Return the RAM cache used by this loader."""
        return self._ram_cache

    @property
    def convert_cache(self) -> ModelConvertCacheBase:
        """Return the checkpoint convert cache used by this loader."""
        return self._convert_cache

    def load_model(
        self,
        model_config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        """
        Given a model's configuration, load it and return the LoadedModel object.

        :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
        :param submodel_type: For main (pipeline) models, the submodel to fetch.
        :param context_data: Invocation context data, used for event reporting.
        """
        # Notify listeners that the load has started.
        if context_data:
            self._emit_load_event(
                context_data=context_data,
                model_config=model_config,
                submodel_type=submodel_type,
            )

        # Look up the loader implementation registered for this model's base/type/format
        # and delegate the actual load to it, sharing this service's caches.
        implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type)  # type: ignore
        loaded_model: LoadedModel = implementation(
            app_config=self._app_config,
            logger=self._logger,
            ram_cache=self._ram_cache,
            convert_cache=self._convert_cache,
        ).load_model(model_config, submodel_type)

        # Notify listeners that the load has completed.
        if context_data:
            self._emit_load_event(
                context_data=context_data,
                model_config=model_config,
                submodel_type=submodel_type,
                loaded=True,
            )
        return loaded_model

    def _emit_load_event(
        self,
        context_data: InvocationContextData,
        model_config: AnyModelConfig,
        loaded: Optional[bool] = False,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
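        # The event payload mirrors the queue item that triggered the load, so that
        # listeners (e.g. the web frontend) can attribute the load to the correct
        # queue item, batch, and session.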
        if not self._invoker:
            return

        if not loaded:
            self._invoker.services.events.emit_model_load_started(
                queue_id=context_data.queue_item.queue_id,
                queue_item_id=context_data.queue_item.item_id,
                queue_batch_id=context_data.queue_item.batch_id,
                graph_execution_state_id=context_data.queue_item.session_id,
                model_config=model_config,
                submodel_type=submodel_type,
            )
        else:
            self._invoker.services.events.emit_model_load_completed(
                queue_id=context_data.queue_item.queue_id,
                queue_item_id=context_data.queue_item.item_id,
                queue_batch_id=context_data.queue_item.batch_id,
                graph_execution_state_id=context_data.queue_item.session_id,
                model_config=model_config,
                submodel_type=submodel_type,
            )