# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright

import gc
import math
import time
from logging import Logger
from typing import Dict, List, Optional

import torch

from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

# Size of a GB in bytes.
GB = 2**30

# Size of a MB in bytes.
MB = 2**20


# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
    if submodel_type:
        return f"{model_key}:{submodel_type.value}"
    else:
        return model_key


class ModelCache:
    """A cache for managing models in memory.

    The cache is based on two levels of model storage:
    - execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
    - storage_device: The device where models are offloaded when not in active use (typically "cpu").

    The model cache is based on the following assumptions:
    - storage_device_mem_size > execution_device_mem_size
    - disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time

    A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
    the execution_device.

    Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
    on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
    policy. The storage_device cache uses a least-recently-used (LRU) offload policy.

    Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
    policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
    configuration.

    Models are inserted with put(), retrieved with get(), moved onto the execution device (often GPU) with lock(), and
    released with unlock(). A locked model will not be offloaded from the execution device.

    Example usage:
    ```
    cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
    cache.put(key, model)
    cache_entry = cache.get(key)
    try:
        cache.lock(key)
        do_something_on_gpu(cache_entry.model)
    finally:
        cache.unlock(key)
    ```
    """

    def __init__(
        self,
        max_cache_size: float,
        max_vram_cache_size: float,
        execution_device: torch.device = torch.device("cuda"),
        storage_device: torch.device = torch.device("cpu"),
        lazy_offloading: bool = True,
        log_memory_usage: bool = False,
        logger: Optional[Logger] = None,
    ):
        """
        Initialize the model RAM cache.

        :param max_cache_size: Maximum size of the storage_device cache in GBs.
        :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model
            cache's behaviour.
        :param logger: InvokeAILogger to use (otherwise creates one)
        """
        # allow lazy offloading only when vram cache enabled
        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self._max_cache_size: float = max_cache_size
        self._max_vram_cache_size: float = max_vram_cache_size
        self._execution_device: torch.device = execution_device
        self._storage_device: torch.device = storage_device
        self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
        self._log_memory_usage = log_memory_usage
        self._stats: Optional[CacheStats] = None

        self._cached_models: Dict[str, CacheRecord] = {}
        self._cache_stack: List[str] = []

    @property
    def max_cache_size(self) -> float:
        """Return the cap on cache size."""
        return self._max_cache_size

    @max_cache_size.setter
    def max_cache_size(self, value: float) -> None:
        """Set the cap on cache size."""
        self._max_cache_size = value

    @property
    def max_vram_cache_size(self) -> float:
        """Return the cap on vram cache size."""
        return self._max_vram_cache_size

    @max_vram_cache_size.setter
    def max_vram_cache_size(self, value: float) -> None:
        """Set the cap on vram cache size."""
        self._max_vram_cache_size = value

    @property
    def stats(self) -> Optional[CacheStats]:
        """Return collected CacheStats object."""
        return self._stats

    @stats.setter
    def stats(self, stats: CacheStats) -> None:
        """Set the CacheStats object for collecting cache statistics."""
        self._stats = stats

    def put(
        self,
        key: str,
        model: AnyModel,
    ) -> None:
        """Insert model into the cache."""
        if key in self._cached_models:
            return
        size = calc_model_size_by_data(self._logger, model)
        self.make_room(size)

        # Inject custom modules into the model.
        if isinstance(model, torch.nn.Module):
            apply_custom_layers_to_model(model)

        running_on_cpu = self._execution_device == torch.device("cpu")
        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
        cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
        self._cached_models[key] = cache_record
        self._cache_stack.append(key)

    def get(
        self,
        key: str,
        stats_name: Optional[str] = None,
    ) -> CacheRecord:
        """Retrieve a model from the cache.

        :param key: Model key
        :param stats_name: A human-readable id for the model for the purposes of stats reporting.

        Raises IndexError if the model is not in the cache.
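
        Example (illustrative sketch; the key shown is hypothetical, in the format produced by get_model_cache_key()):

            record = cache.get("my-model-key:unet")
            model = record.model  # still resident on the storage_device until lock() is called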
""" if key in self._cached_models: if self.stats: self.stats.hits += 1 else: if self.stats: self.stats.misses += 1 raise IndexError(f"The model with key {key} is not in the cache.") cache_entry = self._cached_models[key] # more stats if self.stats: stats_name = stats_name or key self.stats.cache_size = int(self._max_cache_size * GB) self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size()) self.stats.in_cache = len(self._cached_models) self.stats.loaded_model_sizes[stats_name] = max( self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size ) # this moves the entry to the top (right end) of the stack self._cache_stack = [k for k in self._cache_stack if k != key] self._cache_stack.append(key) return cache_entry def lock(self, key: str) -> None: """Lock a model for use and move it into VRAM.""" cache_entry = self._cached_models[key] cache_entry.lock() try: if self._lazy_offloading: self._offload_unlocked_models(cache_entry.size) self._move_model_to_device(cache_entry, self._execution_device) cache_entry.loaded = True self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}") self._print_cuda_stats() except torch.cuda.OutOfMemoryError: self._logger.warning("Insufficient GPU memory to load model. Aborting") cache_entry.unlock() raise except Exception: cache_entry.unlock() raise def unlock(self, key: str) -> None: """Unlock a model.""" cache_entry = self._cached_models[key] cache_entry.unlock() if not self._lazy_offloading: self._offload_unlocked_models(0) self._print_cuda_stats() def _get_cache_size(self) -> int: """Get the total size of the models currently cached.""" total = 0 for cache_record in self._cached_models.values(): total += cache_record.size return total def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: if self._log_memory_usage: return MemorySnapshot.capture() return None def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str: if submodel_type: return f"{model_key}:{submodel_type.value}" else: return model_key def _offload_unlocked_models(self, size_required: int) -> None: """Offload models from the execution_device to make room for size_required. :param size_required: The amount of space to clear in the execution_device cache, in bytes. """ reserved = self._max_vram_cache_size * GB vram_in_use = torch.cuda.memory_allocated() + size_required self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB") for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: break if not cache_entry.loaded: continue if not cache_entry.locked: self._move_model_to_device(cache_entry, self._storage_device) cache_entry.loaded = False vram_in_use = torch.cuda.memory_allocated() + size_required self._logger.debug( f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB" ) TorchDevice.empty_cache() def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None: """Move model into the indicated device. :param cache_entry: The CacheRecord for the model :param target_device: The torch.device to move the model into May raise a torch.cuda.OutOfMemoryError """ self._logger.debug(f"Called to move {cache_entry.key} to {target_device}") source_device = cache_entry.device # Note: We compare device types only so that 'cuda' == 'cuda:0'. # This would need to be revised to support multi-GPU. 
        if torch.device(source_device).type == torch.device(target_device).type:
            return

        # Some models don't have a `to` method, in which case they run in RAM/CPU.
        if not hasattr(cache_entry.model, "to"):
            return

        # This roundabout method for moving the model around is done to avoid
        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
        # When moving to VRAM, we copy (not move) each element of the state dict from
        # RAM to a new state dict in VRAM, and then inject it into the model.
        # This operation is slightly faster than running `to()` on the whole model.
        #
        # When the model needs to be removed from VRAM we simply delete the copy
        # of the state dict in VRAM, and reinject the state dict that is cached
        # in RAM into the model. So this operation is very fast.
        start_model_to_time = time.time()
        snapshot_before = self._capture_memory_snapshot()

        try:
            if cache_entry.state_dict is not None:
                assert hasattr(cache_entry.model, "load_state_dict")
                if target_device == self._storage_device:
                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
                else:
                    new_dict: Dict[str, torch.Tensor] = {}
                    for k, v in cache_entry.state_dict.items():
                        new_dict[k] = v.to(target_device, copy=True)
                    cache_entry.model.load_state_dict(new_dict, assign=True)
            cache_entry.model.to(target_device)
            cache_entry.device = target_device
        except Exception as e:  # blow away cache entry
            self._delete_cache_entry(cache_entry)
            raise e

        snapshot_after = self._capture_memory_snapshot()
        end_model_to_time = time.time()
        self._logger.debug(
            f"Moved model '{cache_entry.key}' from {source_device} to"
            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
            f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
        )

        if (
            snapshot_before is not None
            and snapshot_after is not None
            and snapshot_before.vram is not None
            and snapshot_after.vram is not None
        ):
            vram_change = abs(snapshot_before.vram - snapshot_after.vram)

            # If the estimated model size does not match the change in VRAM, log a warning.
            if not math.isclose(
                vram_change,
                cache_entry.size,
                rel_tol=0.1,
                abs_tol=10 * MB,
            ):
                self._logger.debug(
                    f"Moving model '{cache_entry.key}' from {source_device} to"
                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
                    " estimated size may be incorrect. Estimated model size:"
                    f" {(cache_entry.size/GB):.3f} GB.\n"
                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
                )

    def _print_cuda_stats(self) -> None:
        """Log CUDA diagnostics."""
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
        ram = "%4.2fG" % (self._get_cache_size() / GB)

        in_ram_models = 0
        in_vram_models = 0
        locked_in_vram_models = 0
        for cache_record in self._cached_models.values():
            if hasattr(cache_record.model, "device"):
                if cache_record.model.device == self._storage_device:
                    in_ram_models += 1
                else:
                    in_vram_models += 1
                    if cache_record.locked:
                        locked_in_vram_models += 1

        self._logger.debug(
            f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
            f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
        )

    def make_room(self, size: int) -> None:
        """Make enough room in the cache to accommodate a new model of indicated size.

        Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
        external references to the model, there's nothing that the cache can do about it, and those models will not be
        garbage-collected.
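
        Example (illustrative): `size` is expressed in bytes, so requesting roughly 2 GB of RAM cache space
        would be `cache.make_room(2 * GB)`, using the module-level `GB` constant.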
""" bytes_needed = size maximum_size = self._max_cache_size * GB # stored in GB, convert to bytes current_size = self._get_cache_size() if current_size + bytes_needed > maximum_size: self._logger.debug( f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional" f" {(bytes_needed/GB):.2f} GB" ) self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}") pos = 0 models_cleared = 0 while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): model_key = self._cache_stack[pos] cache_entry = self._cached_models[model_key] device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None self._logger.debug( f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}" ) if not cache_entry.locked: self._logger.debug( f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)" ) current_size -= cache_entry.size models_cleared += 1 self._delete_cache_entry(cache_entry) del cache_entry else: pos += 1 if models_cleared > 0: # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost # is high even if no garbage gets collected.) # # Calling gc.collect(...) when a model is cleared seems like a good middle-ground: # - If models had to be cleared, it's a signal that we are close to our memory limit. # - If models were cleared, there's a good chance that there's a significant amount of garbage to be # collected. # # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up # immediately when their reference count hits 0. if self.stats: self.stats.cleared = models_cleared gc.collect() TorchDevice.empty_cache() self._logger.debug(f"After making room: cached_models={len(self._cached_models)}") def _delete_cache_entry(self, cache_entry: CacheRecord) -> None: self._cache_stack.remove(cache_entry.key) del self._cached_models[cache_entry.key]