mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-05 10:27:02 +08:00
Get rid of ModelLocker. It was an unnecessary layer of indirection.
This commit is contained in:
parent
a39bcf7e85
commit
7dc3e0fdbe
@ -1364,7 +1364,6 @@ the in-memory loaded model:
|
||||
|----------------|-----------------|------------------|
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
| `model` | AnyModel | The instantiated model (details below) |
|
||||
| `locker` | ModelLocker | A context manager that mediates the movement of the model into VRAM |
|
||||
|
||||
### get_model_by_key(key, [submodel]) -> LoadedModel
|
||||
|
||||
|
@ -78,9 +78,8 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
try:
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
@ -109,5 +108,5 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
)
|
||||
assert loader is not None
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
self._ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
|
||||
|
@ -5,7 +5,6 @@ Base class for model loading in InvokeAI.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
@ -18,20 +17,17 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModelWithoutConfig:
|
||||
"""
|
||||
Context manager object that mediates transfer from RAM<->VRAM.
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM.
|
||||
|
||||
This is a context manager object that has two distinct APIs:
|
||||
|
||||
1. Older API (deprecated):
|
||||
Use the LoadedModel object directly as a context manager.
|
||||
It will move the model into VRAM (on CUDA devices), and
|
||||
Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
|
||||
return the model in a form suitable for passing to torch.
|
||||
Example:
|
||||
```
|
||||
@ -41,13 +37,9 @@ class LoadedModelWithoutConfig:
|
||||
```
|
||||
|
||||
2. Newer API (recommended):
|
||||
Call the LoadedModel's `model_on_device()` method in a
|
||||
context. It returns a tuple consisting of a copy of
|
||||
the model's state dict in CPU RAM followed by a copy
|
||||
of the model in VRAM. The state dict is provided to allow
|
||||
LoRAs and other model patchers to return the model to
|
||||
its unpatched state without expensive copy and restore
|
||||
operations.
|
||||
Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
|
||||
model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
|
||||
other model patchers to return the model to its unpatched state without expensive copy and restore operations.
|
||||
|
||||
Example:
|
||||
```
|
||||
@ -56,43 +48,42 @@ class LoadedModelWithoutConfig:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The state_dict should be treated as a read-only object and
|
||||
never modified. Also be aware that some loadable models do
|
||||
not have a state_dict, in which case this value will be None.
|
||||
The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
|
||||
do not have a state_dict, in which case this value will be None.
|
||||
"""
|
||||
|
||||
_locker: ModelLocker
|
||||
def __init__(self, cache_record: CacheRecord, cache: ModelCache):
|
||||
self._cache_record = cache_record
|
||||
self._cache = cache
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
self._locker.lock()
|
||||
self._cache.lock(self._cache_record.key)
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
self._locker.unlock()
|
||||
self._cache.unlock(self._cache_record.key)
|
||||
|
||||
@contextmanager
|
||||
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
|
||||
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
|
||||
locked_model = self._locker.lock()
|
||||
self._cache.lock(self._cache_record.key)
|
||||
try:
|
||||
state_dict = self._locker.get_state_dict()
|
||||
yield (state_dict, locked_model)
|
||||
yield (self._cache_record.state_dict, self._cache_record.model)
|
||||
finally:
|
||||
self._locker.unlock()
|
||||
self._cache.unlock(self._cache_record.key)
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self._locker.model
|
||||
return self._cache_record.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: Optional[AnyModelConfig] = None
|
||||
def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
|
||||
super().__init__(cache_record=cache_record, cache=cache)
|
||||
self.config = config
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
|
@ -14,8 +14,8 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@ -55,8 +55,8 @@ class ModelLoader(ModelLoaderBase):
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
with skip_torch_weight_init():
|
||||
locker = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
cache_record = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCache:
|
||||
@ -67,7 +67,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
model_base = self._app_config.models_path
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLocker:
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
|
||||
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
|
||||
try:
|
||||
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
|
||||
|
@ -13,7 +13,6 @@ from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -153,7 +152,7 @@ class ModelCache:
|
||||
self,
|
||||
key: str,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLocker:
|
||||
) -> CacheRecord:
|
||||
"""Retrieve a model from the cache.
|
||||
|
||||
:param key: Model key
|
||||
@ -185,10 +184,7 @@ class ModelCache:
|
||||
self._cache_stack = [k for k in self._cache_stack if k != key]
|
||||
self._cache_stack.append(key)
|
||||
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
return cache_entry
|
||||
|
||||
def lock(self, key: str) -> None:
|
||||
"""Lock a model for use and move it into VRAM."""
|
||||
|
@ -1,35 +0,0 @@
|
||||
"""
|
||||
Base class and implementation of a class that moves models in and out of VRAM.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
|
||||
|
||||
class ModelLocker:
|
||||
def __init__(self, cache: ModelCache, cache_entry: CacheRecord):
|
||||
self._cache = cache
|
||||
self._cache_entry = cache_entry
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without moving it around."""
|
||||
return self._cache_entry.model
|
||||
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return self._cache_entry.state_dict
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
self._cache.lock(self._cache_entry.key)
|
||||
return self.model
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock a model."""
|
||||
self._cache.unlock(self._cache_entry.key)
|
Loading…
Reference in New Issue
Block a user