"""
Base class and implementation of a class that moves models in and out of VRAM.
"""

from typing import Dict, Optional

import torch

from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache


class ModelLocker:
    """Internal class that mediates movement in and out of GPU."""

    def __init__(self, cache: ModelCache, cache_entry: CacheRecord):
        """
        Initialize the model locker.

        :param cache: The ModelCache object
        :param cache_entry: The entry in the model cache
        """
        self._cache = cache
        self._cache_entry = cache_entry

    @property
    def model(self) -> AnyModel:
        """Return the model without moving it around."""
        return self._cache_entry.model

    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
        """Return the state dict (if any) for the cached model."""
        return self._cache_entry.state_dict

    def lock(self) -> AnyModel:
        """Move the model into the execution device (GPU) and lock it."""
        self._cache_entry.lock()
        try:
            # When lazy offloading is enabled, first make room for the incoming model,
            # then move it onto the execution device.
            if self._cache.lazy_offloading:
                self._cache.offload_unlocked_models(self._cache_entry.size)
            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
            self._cache_entry.loaded = True
            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
            self._cache.print_cuda_stats()
        except torch.cuda.OutOfMemoryError:
            self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
            # Release the lock before re-raising so the cache entry is not left pinned.
            self._cache_entry.unlock()
            raise
        except Exception:
            self._cache_entry.unlock()
            raise

        return self.model

    def unlock(self) -> None:
        """Call upon exit from context."""
        self._cache_entry.unlock()
        if not self._cache.lazy_offloading:
            # Without lazy offloading, evict unlocked models from the GPU right away.
            self._cache.offload_unlocked_models(0)
        self._cache.print_cuda_stats()
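

# Illustrative usage (a hedged sketch, not part of this module's public API): a
# caller holding a populated ModelCache and a CacheRecord for the model it wants
# would typically pair lock() and unlock() so the model is released even if
# inference fails. The `cache`, `cache_entry`, and `run_inference` names below
# are placeholders for this example only.
#
#     locker = ModelLocker(cache, cache_entry)
#     model = locker.lock()  # moves the model onto the execution device
#     try:
#         run_inference(model)
#     finally:
#         locker.unlock()  # lets the cache offload the model again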