Partial Loading PR1: Tidy ModelCache (#7492)

## Summary

This PR tidies up the model cache code in preparation for further
refactoring to support partial loading of models onto the GPU. **These
code changes should not change the functional behavior in any way.**

Changes:
- Remove the `ModelCacheBase` class. `ModelCache` is the only
implementation, so there is no benefit to the separate abstract class.
- Split `CacheRecord` and `CacheStats` out into their own files.
- Remove the `ModelLocker` class. This extra layer of indirection was
not providing any benefit. Locking is now done directly with the
`ModelCache`.
- Tidy up relative imports that were contributing to circular import
issues.
- Pull the 'submodel' concern out of the `ModelCache`. The `ModelCache`
should not need to be aware of the model manager submodel system.
- Delete unused properties from the `ModelCache` (e.g.
`.lazy_offloading`, `.storage_device`, etc.)

## QA Instructions

I ran smoke tests with a variety of SD1, SDXL and FLUX models. No change
to behavior is expected.

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
Ryan Dick 2024-12-24 09:30:44 -05:00 committed by GitHub
commit d3916dbdb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 197 additions and 438 deletions

View File

@ -1364,7 +1364,6 @@ the in-memory loaded model:
|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
### get_model_by_key(key, [submodel]) -> LoadedModel

View File

@ -37,7 +37,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch

View File

@ -20,7 +20,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
NodeExecutionStatsSummary,
)
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load.model_cache import CacheStats
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
# Size of 1GB in bytes.
GB = 2**30

View File

@ -7,7 +7,7 @@ from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
class ModelLoadServiceBase(ABC):
@ -24,7 +24,7 @@ class ModelLoadServiceBase(ABC):
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the RAM cache used by this loader."""
@abstractmethod

View File

@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load import (
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
def __init__(
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
@ -45,7 +45,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._invoker = invoker
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the RAM cache used by this loader."""
return self._ram_cache
@ -78,9 +78,8 @@ class ModelLoadService(ModelLoadServiceBase):
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
except IndexError:
pass
@ -109,5 +108,5 @@ class ModelLoadService(ModelLoadServiceBase):
)
assert loader is not None
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
self._ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)

View File

@ -16,7 +16,8 @@ from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBas
from invokeai.app.services.model_load.model_load_default import ModelLoadService
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

View File

@ -8,7 +8,7 @@ from pathlib import Path
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
# This registers the subclasses that implement loaders of specific model types

View File

@ -5,7 +5,6 @@ Base class for model loading in InvokeAI.
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Tuple
@ -18,19 +17,17 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
SubModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
@dataclass
class LoadedModelWithoutConfig:
"""
Context manager object that mediates transfer from RAM<->VRAM.
"""Context manager object that mediates transfer from RAM<->VRAM.
This is a context manager object that has two distinct APIs:
1. Older API (deprecated):
Use the LoadedModel object directly as a context manager.
It will move the model into VRAM (on CUDA devices), and
Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
return the model in a form suitable for passing to torch.
Example:
```
@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
```
2. Newer API (recommended):
Call the LoadedModel's `model_on_device()` method in a
context. It returns a tuple consisting of a copy of
the model's state dict in CPU RAM followed by a copy
of the model in VRAM. The state dict is provided to allow
LoRAs and other model patchers to return the model to
its unpatched state without expensive copy and restore
operations.
Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
other model patchers to return the model to its unpatched state without expensive copy and restore operations.
Example:
```
@ -55,43 +48,42 @@ class LoadedModelWithoutConfig:
image = vae.decode(latents)[0]
```
The state_dict should be treated as a read-only object and
never modified. Also be aware that some loadable models do
not have a state_dict, in which case this value will be None.
The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
do not have a state_dict, in which case this value will be None.
"""
_locker: ModelLockerBase
def __init__(self, cache_record: CacheRecord, cache: ModelCache):
self._cache_record = cache_record
self._cache = cache
def __enter__(self) -> AnyModel:
"""Context entry."""
self._locker.lock()
self._cache.lock(self._cache_record.key)
return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""
self._locker.unlock()
self._cache.unlock(self._cache_record.key)
@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
locked_model = self._locker.lock()
self._cache.lock(self._cache_record.key)
try:
state_dict = self._locker.get_state_dict()
yield (state_dict, locked_model)
yield (self._cache_record.state_dict, self._cache_record.model)
finally:
self._locker.unlock()
self._cache.unlock(self._cache_record.key)
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self._locker.model
return self._cache_record.model
@dataclass
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: Optional[AnyModelConfig] = None
def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
super().__init__(cache_record=cache_record, cache=cache)
self.config = config
# TODO(MM2):
@ -110,7 +102,7 @@ class ModelLoaderBase(ABC):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
):
"""Initialize the loader."""
pass
@ -138,6 +130,6 @@ class ModelLoaderBase(ABC):
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the ram cache associated with this loader."""
pass

View File

@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@ -28,7 +29,7 @@ class ModelLoader(ModelLoaderBase):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
):
"""Initialize the loader."""
self._app_config = app_config
@ -54,11 +55,11 @@ class ModelLoader(ModelLoaderBase):
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
with skip_torch_weight_init():
locker = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, _locker=locker)
cache_record = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the ram cache associated with this loader."""
return self._ram_cache
@ -66,10 +67,10 @@ class ModelLoader(ModelLoaderBase):
model_base = self._app_config.models_path
return (model_base / config.path).resolve()
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
try:
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
except IndexError:
pass
@ -78,16 +79,11 @@ class ModelLoader(ModelLoaderBase):
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(
config.key,
submodel_type=submodel_type,
get_model_cache_key(config.key, submodel_type),
model=loaded_model,
)
return self._ram_cache.get(
key=config.key,
submodel_type=submodel_type,
stats_name=stats_name,
)
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
def get_size_fs(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None

View File

@ -1,6 +0,0 @@
"""Init file for ModelCache."""
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
from .model_cache_default import ModelCache # noqa F401
_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]

View File

@ -0,0 +1,50 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
@dataclass
class CacheRecord:
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str
model: Any
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
"""Lock this record."""
self._locks += 1
def unlock(self) -> None:
"""Unlock this record."""
self._locks -= 1
assert self._locks >= 0
@property
def locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0

View File

@ -0,0 +1,15 @@
from dataclasses import dataclass, field
from typing import Dict
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)

View File

@ -1,11 +1,9 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
""" """
import gc
import math
import time
from contextlib import suppress
from logging import Logger
from typing import Dict, List, Optional
@ -13,13 +11,8 @@ import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
CacheRecord,
CacheStats,
ModelCacheBase,
ModelLockerBase,
)
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@ -31,7 +24,15 @@ GB = 2**30
MB = 2**20
class ModelCache(ModelCacheBase[AnyModel]):
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
if submodel_type:
return f"{model_key}:{submodel_type.value}"
else:
return model_key
class ModelCache:
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
@ -70,7 +71,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
lazy_offloading: bool = True,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
@ -82,7 +82,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
@ -100,29 +99,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cached_models: Dict[str, CacheRecord] = {}
self._cache_stack: List[str] = []
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
return self._logger
@property
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
return self._lazy_offloading
@property
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
return self._storage_device
@property
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device
@property
def max_cache_size(self) -> float:
"""Return the cap on cache size."""
@ -153,49 +132,35 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Set the CacheStats object for collectin cache statistics."""
self._stats = stats
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
for cache_record in self._cached_models.values():
total += cache_record.size
return total
def put(
self,
key: str,
model: AnyModel,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
"""Insert model into the cache."""
if key in self._cached_models:
return
size = calc_model_size_by_data(self.logger, model)
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)
running_on_cpu = self.execution_device == torch.device("cpu")
running_on_cpu = self._execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
def get(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
stats_name: Optional[str] = None,
) -> ModelLockerBase:
"""
Retrieve model using key and optional submodel_type.
) -> CacheRecord:
"""Retrieve a model from the cache.
:param key: Opaque model key
:param submodel_type: Type of the submodel to fetch
:param stats_name: A human-readable id for the model for the purposes of
stats reporting.
:param key: Model key
:param stats_name: A human-readable id for the model for the purposes of stats reporting.
This may raise an IndexError if the model is not in the cache.
Raises IndexError if the model is not in the cache.
"""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
self.stats.hits += 1
@ -210,20 +175,52 @@ class ModelCache(ModelCacheBase[AnyModel]):
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack = [k for k in self._cache_stack if k != key]
self._cache_stack.append(key)
return ModelLocker(
cache=self,
cache_entry=cache_entry,
)
return cache_entry
def lock(self, key: str) -> None:
"""Lock a model for use and move it into VRAM."""
cache_entry = self._cached_models[key]
cache_entry.lock()
try:
if self._lazy_offloading:
self._offload_unlocked_models(cache_entry.size)
self._move_model_to_device(cache_entry, self._execution_device)
cache_entry.loaded = True
self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
self._print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._logger.warning("Insufficient GPU memory to load model. Aborting")
cache_entry.unlock()
raise
except Exception:
cache_entry.unlock()
raise
def unlock(self, key: str) -> None:
"""Unlock a model."""
cache_entry = self._cached_models[key]
cache_entry.unlock()
if not self._lazy_offloading:
self._offload_unlocked_models(0)
self._print_cuda_stats()
def _get_cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
for cache_record in self._cached_models.values():
total += cache_record.size
return total
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
@ -236,30 +233,30 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
return model_key
def offload_unlocked_models(self, size_required: int) -> None:
def _offload_unlocked_models(self, size_required: int) -> None:
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
self._move_model_to_device(cache_entry, self._storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
self._logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
@ -267,7 +264,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
May raise a torch.cuda.OutOfMemoryError
"""
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
@ -294,7 +291,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
if target_device == self._storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
@ -309,7 +306,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
self._logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
@ -331,7 +328,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
self._logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
@ -339,24 +336,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
def _print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self.cache_size() / GB)
ram = "%4.2fG" % (self._get_cache_size() / GB)
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
if hasattr(cache_record.model, "device"):
if cache_record.model.device == self.storage_device:
if cache_record.model.device == self._storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked:
locked_in_vram_models += 1
self.logger.debug(
self._logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
@ -369,16 +366,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
garbage-collected.
"""
bytes_needed = size
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
current_size = self.cache_size()
maximum_size = self._max_cache_size * GB # stored in GB, convert to bytes
current_size = self._get_cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
self._logger.debug(
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
@ -386,12 +383,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
self._logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
if not cache_entry.locked:
self.logger.debug(
self._logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
@ -419,8 +416,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
gc.collect()
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
self._logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]

View File

@ -1,221 +0,0 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generic, Optional, TypeVar
import torch
from invokeai.backend.model_manager.config import AnyModel, SubModelType
class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader."""
@abstractmethod
def lock(self) -> AnyModel:
"""Lock the contained model and move it into VRAM."""
pass
@abstractmethod
def unlock(self) -> None:
"""Unlock the contained model, and remove it from VRAM."""
pass
@abstractmethod
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
"""Return the model."""
pass
T = TypeVar("T")
@dataclass
class CacheRecord(Generic[T]):
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
"""Lock this record."""
self._locks += 1
def unlock(self) -> None:
"""Unlock this record."""
self._locks -= 1
assert self._locks >= 0
@property
def locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelCacheBase(ABC, Generic[T]):
"""Virtual base class for RAM model cache."""
@property
@abstractmethod
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
pass
@property
@abstractmethod
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
pass
@property
@abstractmethod
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@property
@abstractmethod
def max_cache_size(self) -> float:
"""Return the maximum size the RAM cache can grow to."""
pass
@max_cache_size.setter
@abstractmethod
def max_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
@property
@abstractmethod
def max_vram_cache_size(self) -> float:
"""Return the maximum size the VRAM cache can grow to."""
pass
@max_vram_cache_size.setter
@abstractmethod
def max_vram_cache_size(self, value: float) -> float:
"""Set the maximum size the VRAM cache can grow to."""
pass
@abstractmethod
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload from VRAM any models not actively in use."""
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device."""
pass
@property
@abstractmethod
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
pass
@stats.setter
@abstractmethod
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collectin cache statistics."""
pass
@property
@abstractmethod
def logger(self) -> Logger:
"""Return the logger used by the cache."""
pass
@abstractmethod
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
pass
@abstractmethod
def put(
self,
key: str,
model: T,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
pass
@abstractmethod
def get(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
stats_name: Optional[str] = None,
) -> ModelLockerBase:
"""
Retrieve model using key and optional submodel_type.
:param key: Opaque model key
:param submodel_type: Type of the submodel to fetch
:param stats_name: A human-readable id for the model for the purposes of
stats reporting.
This may raise an IndexError if the model is not in the cache.
"""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
pass
@abstractmethod
def print_cuda_stats(self) -> None:
"""Log debugging information on CUDA usage."""
pass

View File

@ -1,64 +0,0 @@
"""
Base class and implementation of a class that moves models in and out of VRAM.
"""
from typing import Dict, Optional
import torch
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
CacheRecord,
ModelCacheBase,
ModelLockerBase,
)
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
"""
Initialize the model locker.
:param cache: The ModelCache object
:param cache_entry: The entry in the model cache
"""
self._cache = cache
self._cache_entry = cache_entry
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
self._cache_entry.unlock()
raise
except Exception:
self._cache_entry.unlock()
raise
return self.model
def unlock(self) -> None:
"""Call upon exit from context."""
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats()

View File

@ -18,7 +18,7 @@ from invokeai.backend.model_manager import (
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
@ -47,7 +47,7 @@ class LoRALoader(ModelLoader):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
):
"""Initialize the loader."""
super().__init__(app_config, logger, ram_cache)

View File

@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import (
DiffusersConfigBase,
MainCheckpointConfig,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -132,5 +133,5 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel)
return getattr(pipeline, submodel_type.value)

View File

@ -25,7 +25,7 @@ from invokeai.backend.model_manager.config import (
ModelVariantType,
VAEDiffusersConfig,
)
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_metadata.metadata_examples import (
HFTestLoraMetadata,