WIP - first pass at overhauling ModelCache to work with partial loads.

Ryan Dick 2024-12-05 23:03:40 +00:00
parent 8e409e3436
commit c7b84cf012
6 changed files with 340 additions and 163 deletions

View File

@@ -68,7 +68,7 @@ class LoadedModelWithoutConfig:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
self._cache.lock(self._cache_record.key)
try:
yield (self._cache_record.state_dict, self._cache_record.model)
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
finally:
self._cache.unlock(self._cache_record.key)
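For context, a minimal usage sketch of this context manager from the caller's side (variable names, the input, and the patching helper are illustrative assumptions, not part of this diff):

    # `loaded_model` is assumed to be a LoadedModelWithoutConfig obtained from the model manager.
    with loaded_model.model_on_device() as (cpu_state_dict, model):
        # The model is locked onto the execution device for the duration of the block.
        # In this WIP commit, get_cpu_state_dict() may return None, so callers must
        # handle the missing-state-dict case.
        if cpu_state_dict is not None:
            patch_model(model, cpu_state_dict)  # hypothetical patching helper
        output = model(example_input)  # hypothetical forward pass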

View File

@@ -1,38 +1,22 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
@dataclass
class CacheRecord:
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
"""A class that represents a model in the model cache."""
# Cache key.
key: str
model: Any
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
# Model in memory.
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
# If locks > 0, the model is actively being used, so we should do our best to keep it on the compute device.
_locks: int = 0
def lock(self) -> None:
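The hunk is truncated at `lock()`. Given the `_locks` counter introduced above and the `is_locked` checks used later in ModelCache, the bookkeeping is presumably along these lines (a sketch inferred from usage, not shown in this diff):

    def lock(self) -> None:
        self._locks += 1

    def unlock(self) -> None:
        self._locks -= 1
        assert self._locks >= 0

    @property
    def is_locked(self) -> bool:
        return self._locks > 0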

View File

@@ -28,10 +28,22 @@ class CachedModelOnlyFullLoad:
def model(self) -> torch.nn.Module:
return self._model
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better and implement it.
return None
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return self._total_bytes
def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
if self._is_in_vram:
return self._total_bytes
else:
return 0
def is_in_vram(self) -> bool:
"""Return true if the model is currently in VRAM."""
return self._is_in_vram
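A brief usage sketch of the full-load wrapper's byte accounting (constructor arguments follow the positional call used in ModelCache later in this commit; `model` and `size_in_bytes` are assumptions):

    # Sketch only: assumes `import torch`, a CUDA device, and a model whose size in bytes is known.
    cached = CachedModelOnlyFullLoad(model, torch.device("cuda"), size_in_bytes)
    assert cached.cur_vram_bytes() == 0                      # nothing in VRAM yet
    cached.full_load_to_vram()                               # all-or-nothing load
    assert cached.cur_vram_bytes() == cached.total_bytes()
    cached.full_unload_from_vram()
    assert not cached.is_in_vram()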

View File

@@ -1,5 +1,8 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
add_autocast_to_module_forward,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -14,12 +17,21 @@ class CachedModelWithPartialLoad:
self._model = model
self._compute_device = compute_device
# Monkey-patch the model to add autocasting to the model's forward method.
add_autocast_to_module_forward(model, compute_device)
# TODO(ryand): Manage a read-only CPU copy of the model state dict.
# TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?
@property
def model(self) -> torch.nn.Module:
return self._model
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better and implement it.
return None
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return sum(calc_tensor_size(p) for p in self._model.parameters())
@@ -28,6 +40,14 @@ class CachedModelWithPartialLoad:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
return sum(calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type)
def full_load_to_vram(self) -> int:
"""Load all weights into VRAM."""
return self.partial_load_to_vram(self.total_bytes())
def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.

View File

@@ -1,18 +1,19 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
import gc
import math
import time
from logging import Logger
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -137,9 +138,15 @@ class ModelCache:
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)
running_on_cpu = self._execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
# Wrap model.
if isinstance(model, torch.nn.Module):
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
# running_on_cpu = self._execution_device == torch.device("cpu")
# state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -165,10 +172,10 @@ class ModelCache:
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size())
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
)
# this moves the entry to the top (right end) of the stack
@@ -183,19 +190,53 @@ class ModelCache:
cache_entry.lock()
try:
if self._lazy_offloading:
self._offload_unlocked_models(cache_entry.size)
self._move_model_to_device(cache_entry, self._execution_device)
cache_entry.loaded = True
vram_available = self._get_vram_available()
# The amount of additional VRAM that will be used if we fully load the model into VRAM.
vram_needed_for_model = cache_entry.cached_model.total_bytes() - cache_entry.cached_model.cur_vram_bytes()
# Make room for the model in VRAM.
# 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
# 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
# possible.
self._offload_unlocked_models(vram_needed_for_model)
# Check the updated vram_available after offloading.
vram_available = self._get_vram_available()
# Move as much of the model as possible into VRAM.
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
cache_entry.cached_model.partial_load_to_vram(vram_available)
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
# Partial load is not supported, so we have no choice but to try to fit it all into VRAM.
cache_entry.cached_model.full_load_to_vram()
else:
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
self._print_cuda_stats()
# TODO(ryand): Revive this.
# self._print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._logger.warning("Insufficient GPU memory to load model. Aborting")
cache_entry.unlock()
raise
except Exception:
finally:
cache_entry.unlock()
raise
# try:
# if self._lazy_offloading:
# self._offload_unlocked_models(cache_entry.size)
# self._move_model_to_device(cache_entry, self._execution_device)
# cache_entry.loaded = True
# self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
# self._print_cuda_stats()
# except torch.cuda.OutOfMemoryError:
# self._logger.warning("Insufficient GPU memory to load model. Aborting")
# cache_entry.unlock()
# raise
# except Exception:
# cache_entry.unlock()
# raise
def unlock(self, key: str) -> None:
"""Unlock a model."""
@@ -203,14 +244,23 @@ class ModelCache:
cache_entry.unlock()
if not self._lazy_offloading:
self._offload_unlocked_models(0)
self._print_cuda_stats()
# self._print_cuda_stats()
def _get_cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
for cache_record in self._cached_models.values():
total += cache_record.size
return total
def _get_vram_available(self) -> int:
"""Get the amount of VRAM available in the cache."""
# Calculate the total amount of VRAM currently in use.
total_vram_in_use = sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
# The amount of VRAM available in the cache.
return int(self._max_vram_cache_size * GB) - total_vram_in_use
def _get_ram_available(self) -> int:
"""Get the amount of RAM available in the cache."""
total_ram_in_use = self._get_ram_in_use()
return int(self._max_cache_size * GB) - total_ram_in_use
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
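To make the budget arithmetic concrete (numbers purely illustrative): with `_max_vram_cache_size = 8` and two cached models currently holding 3 GB and 1.5 GB in VRAM, `_get_vram_available()` returns 8 GB - 4.5 GB = 3.5 GB; `_get_ram_available()` applies the same pattern to the RAM budget using `_get_ram_in_use()`.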
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
@@ -223,113 +273,143 @@ class ModelCache:
else:
return model_key
def _offload_unlocked_models(self, size_required: int) -> None:
"""Offload models from the execution_device to make room for size_required.
def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int:
"""Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are
offloaded. Of course, locked models are not offloaded.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
Returns:
int: The number of bytes freed.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
# TODO(ryand): Should we support both LRU and smallest-first offloading policies? I can imagine scenarios where
# each would win.
self._logger.debug(f"Offloading unlocked models to free {vram_bytes_to_free/GB:.2f}GB of VRAM.")
vram_bytes_freed = 0
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
for cache_entry in cache_entries_increasing_size:
if vram_bytes_freed >= vram_bytes_to_free:
break
if not cache_entry.loaded:
if cache_entry.is_locked:
continue
if not cache_entry.is_locked:
self._move_model_to_device(cache_entry, self._storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self._logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
cache_entry_bytes_freed = cache_entry.cached_model.partial_unload_from_vram(
vram_bytes_to_free - vram_bytes_freed
)
self._logger.debug(
f"Partially unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/GB):.2f}GB."
)
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
cache_entry_bytes_freed = cache_entry.cached_model.full_unload_from_vram()
self._logger.debug(
f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/GB):.2f}GB."
)
else:
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
vram_bytes_freed += cache_entry_bytes_freed
return vram_bytes_freed
# reserved = self._max_vram_cache_size * GB
# vram_in_use = torch.cuda.memory_allocated() + size_required
# self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
# for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
# if vram_in_use <= reserved:
# break
# if not cache_entry.loaded:
# continue
# if not cache_entry.is_locked:
# self._move_model_to_device(cache_entry, self._storage_device)
# cache_entry.loaded = False
# vram_in_use = torch.cuda.memory_allocated() + size_required
# self._logger.debug(
# f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
# )
TorchDevice.empty_cache()
def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
"""Move model into the indicated device.
# def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
# """Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
# :param cache_entry: The CacheRecord for the model
# :param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# May raise a torch.cuda.OutOfMemoryError
# """
# self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
# source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# # Note: We compare device types only so that 'cuda' == 'cuda:0'.
# # This would need to be revised to support multi-GPU.
# if torch.device(source_device).type == torch.device(target_device).type:
# return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# # Some models don't have a `to` method, in which case they run in RAM/CPU.
# if not hasattr(cache_entry.model, "to"):
# return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
# # This roundabout method for moving the model around is done to avoid
# # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# # When moving to VRAM, we copy (not move) each element of the state dict from
# # RAM to a new state dict in VRAM, and then inject it into the model.
# # This operation is slightly faster than running `to()` on the whole model.
# #
# # When the model needs to be removed from VRAM we simply delete the copy
# # of the state dict in VRAM, and reinject the state dict that is cached
# # in RAM into the model. So this operation is very fast.
# start_model_to_time = time.time()
# snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self._storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(target_device, copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
# try:
# if cache_entry.state_dict is not None:
# assert hasattr(cache_entry.model, "load_state_dict")
# if target_device == self._storage_device:
# cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
# else:
# new_dict: Dict[str, torch.Tensor] = {}
# for k, v in cache_entry.state_dict.items():
# new_dict[k] = v.to(target_device, copy=True)
# cache_entry.model.load_state_dict(new_dict, assign=True)
# cache_entry.model.to(target_device)
# cache_entry.device = target_device
# except Exception as e: # blow away cache entry
# self._delete_cache_entry(cache_entry)
# raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self._logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
# snapshot_after = self._capture_memory_snapshot()
# end_model_to_time = time.time()
# self._logger.debug(
# f"Moved model '{cache_entry.key}' from {source_device} to"
# f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
# f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
# )
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# if (
# snapshot_before is not None
# and snapshot_after is not None
# and snapshot_before.vram is not None
# and snapshot_after.vram is not None
# ):
# vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self._logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
# # If the estimated model size does not match the change in VRAM, log a warning.
# if not math.isclose(
# vram_change,
# cache_entry.size,
# rel_tol=0.1,
# abs_tol=10 * MB,
# ):
# self._logger.debug(
# f"Moving model '{cache_entry.key}' from {source_device} to"
# f" {target_device} caused an unexpected change in VRAM usage. The model's"
# " estimated size may be incorrect. Estimated model size:"
# f" {(cache_entry.size/GB):.3f} GB.\n"
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
# )
def _print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self._get_cache_size() / GB)
ram = "%4.2fG" % (self._get_ram_in_use() / GB)
in_ram_models = 0
in_vram_models = 0
@@ -348,47 +428,62 @@ class ModelCache:
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, size: int) -> None:
def make_room(self, bytes_needed: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
bytes_needed = size
maximum_size = self._max_cache_size * GB # stored in GB, convert to bytes
current_size = self._get_cache_size()
if current_size + bytes_needed > maximum_size:
self._logger.debug(
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
# TODO(ryand): Add debug logging.
ram_bytes_available = self._get_ram_available()
ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)
ram_bytes_freed = 0
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self._logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
if not cache_entry.is_locked:
self._logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
ram_bytes_freed += cache_entry.cached_model.total_bytes()
self._delete_cache_entry(cache_entry)
del cache_entry
models_cleared += 1
else:
pos += 1
# if current_size + bytes_needed > maximum_size:
# self._logger.debug(
# f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
# f" {(bytes_needed/GB):.2f} GB"
# )
# self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
# pos = 0
# models_cleared = 0
# while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
# model_key = self._cache_stack[pos]
# cache_entry = self._cached_models[model_key]
# device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
# self._logger.debug(
# f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
# )
# if not cache_entry.is_locked:
# self._logger.debug(
# f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
# )
# current_size -= cache_entry.size
# models_cleared += 1
# self._delete_cache_entry(cache_entry)
# del cache_entry
# else:
# pos += 1
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost

View File

@@ -70,3 +70,69 @@ def test_cached_model_partial_unload(device: str):
assert freed_bytes >= bytes_to_free
assert freed_bytes < model_total_bytes
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
@parameterize_mps_and_cuda
def test_cached_model_full_load(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
loaded_bytes = cached_model.full_load_to_vram()
assert loaded_bytes > 0
assert loaded_bytes == model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert all(p.device.type == device for p in cached_model.model.parameters())
@parameterize_mps_and_cuda
def test_cached_model_full_load_from_partial(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
# Full load the rest of the model into VRAM.
loaded_bytes_2 = cached_model.full_load_to_vram()
assert loaded_bytes_2 > 0
assert loaded_bytes_2 < model_total_bytes
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
assert all(p.device.type == device for p in cached_model.model.parameters())
@parameterize_mps_and_cuda
def test_cached_model_full_unload(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
# Full unload the model from VRAM.
unloaded_bytes = cached_model.full_unload_from_vram()
assert unloaded_bytes > 0
assert unloaded_bytes == loaded_bytes
assert cached_model.cur_vram_bytes() == 0
assert all(p.device.type == "cpu" for p in cached_model.model.parameters())