"""
Base class and implementation of a class that moves models in and out of VRAM.
"""
import torch

from invokeai.backend.model_manager import AnyModel

from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase


class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
"""
Initialize the model locker.
:param cache: The ModelCache object
:param cache_entry: The entry in the model cache
"""
self._cache = cache
self._cache_entry = cache_entry
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
    def lock(self) -> AnyModel:
        """Move the model into the execution device (GPU) and lock it."""
        # NOTE: the model must implement to() for this code to be able to
        # move it into the GPU; models without it are returned in place.
        if not hasattr(self.model, "to"):
            return self.model

        self._cache_entry.lock()
        try:
            if self._cache.lazy_offloading:
                # Make room in VRAM for this model before moving it in.
                self._cache.offload_unlocked_models(self._cache_entry.size)
            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
            self._cache_entry.loaded = True
            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
            self._cache.print_cuda_stats()
        except torch.cuda.OutOfMemoryError:
            self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
            self._cache_entry.unlock()
            raise
        except Exception:
            # Don't leave the entry locked if the move fails for any reason.
            self._cache_entry.unlock()
            raise
        return self.model
    def unlock(self) -> None:
        """Call upon exit from context."""
        if not hasattr(self.model, "to"):
            return

        self._cache_entry.unlock()
        if not self._cache.lazy_offloading:
            # Eagerly offload unlocked models now, rather than waiting
            # for the next lock() call to free up VRAM.
            self._cache.offload_unlocked_models(0)
            self._cache.print_cuda_stats()
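
# A minimal usage sketch (illustrative only; not part of this module). The
# `cache` and `cache_entry` objects are assumed to come from a populated
# ModelCache elsewhere in the application, and `run_inference` is a
# hypothetical stand-in for caller-side work done while the model is locked:
#
#     locker = ModelLocker(cache, cache_entry)
#     try:
#         model = locker.lock()   # model is now on cache.execution_device
#         run_inference(model)    # work while the model is pinned in VRAM
#     finally:
#         locker.unlock()         # allow the cache to offload the model again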