Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

36 lines
1.1 KiB
Python
Raw Normal View History

"""
Implementation of a class that moves models in and out of VRAM.
"""
from typing import Dict, Optional
2024-02-24 11:25:40 -05:00
import torch
2024-02-26 17:30:37 +11:00
from invokeai.backend.model_manager import AnyModel
2024-12-04 21:53:19 +00:00
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
2024-12-04 21:47:11 +00:00
class ModelLocker:
    """Lock a cached model onto the execution device while it is in use.

    A thin wrapper around a ``ModelCache`` and one of its ``CacheRecord``
    entries: ``lock()`` and ``unlock()`` delegate to the cache, passing the
    entry's ``key``.

    The locker may also be used as a context manager, which guarantees that
    ``unlock()`` is called even if the body raises::

        with locker as model:
            ...  # use model
    """

    def __init__(self, cache: ModelCache, cache_entry: CacheRecord):
        """
        :param cache: The model cache whose ``lock``/``unlock`` are invoked
            with ``cache_entry.key``.
        :param cache_entry: The cache record for the model being managed.
        """
        self._cache = cache
        self._cache_entry = cache_entry

    @property
    def model(self) -> AnyModel:
        """Return the model without moving it around."""
        return self._cache_entry.model

    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
        """Return the state dict (if any) for the cached model."""
        return self._cache_entry.state_dict

    def lock(self) -> AnyModel:
        """Move the model into the execution device (GPU) and lock it.

        :returns: The cached model (same object as ``self.model``).
        """
        self._cache.lock(self._cache_entry.key)
        return self.model

    def unlock(self) -> None:
        """Unlock a model."""
        self._cache.unlock(self._cache_entry.key)

    def __enter__(self) -> AnyModel:
        """Lock the model and return it for use in a ``with`` block."""
        return self.lock()

    def __exit__(self, *exc_info: object) -> None:
        """Unlock the model on leaving the ``with`` block.

        Returns ``None``, so any exception raised in the body propagates.
        """
        self.unlock()