Create CachedModelOnlyFullLoad class.

Ryan Dick 2024-12-05 18:43:50 +00:00
parent 91c5af1b95
commit 987393853c
4 changed files with 204 additions and 0 deletions


@@ -0,0 +1,69 @@
from typing import Any

import torch


class CachedModelOnlyFullLoad:
    """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.

    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
        """Initialize a CachedModelOnlyFullLoad.

        Args:
            model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
            compute_device (torch.device): The compute device to move the model to.
            total_bytes (int): The total size (in bytes) of all the weights in the model.
        """
        # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
        self._model = model
        self._compute_device = compute_device
        self._total_bytes = total_bytes
        self._is_in_vram = False

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def total_bytes(self) -> int:
        """Get the total size (in bytes) of all the weights in the model."""
        return self._total_bytes

    def is_in_vram(self) -> bool:
        """Return true if the model is currently in VRAM."""
        return self._is_in_vram

    def full_load_to_vram(self) -> int:
        """Load all weights into VRAM (if supported by the model).

        Returns:
            The number of bytes loaded into VRAM.
        """
        if self._is_in_vram:
            # Already in VRAM.
            return 0

        if not hasattr(self._model, "to"):
            # Model doesn't support moving to a device.
            return 0

        self._model.to(self._compute_device)
        self._is_in_vram = True
        return self._total_bytes

    def full_unload_from_vram(self) -> int:
        """Unload all weights from VRAM.

        Returns:
            The number of bytes unloaded from VRAM.
        """
        if not self._is_in_vram:
            # Already in RAM.
            return 0

        self._model.to("cpu")
        self._is_in_vram = False
        return self._total_bytes
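
For orientation, here is a minimal usage sketch of the new wrapper (not part of this commit). It assumes a CUDA compute device and uses a bare torch.nn.Linear as a stand-in model; total_bytes is supplied by the caller (110 float32 parameters * 4 bytes = 440).

import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
    CachedModelOnlyFullLoad,
)

model = torch.nn.Linear(10, 10)  # stand-in for a real model; 10*10 weights + 10 biases = 110 params
cached = CachedModelOnlyFullLoad(model=model, compute_device=torch.device("cuda"), total_bytes=440)

loaded = cached.full_load_to_vram()     # moves the model to CUDA and returns 440; returns 0 if already loaded
# ... run inference against cached.model here ...
freed = cached.full_unload_from_vram()  # moves the model back to the CPU and returns 440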


@@ -0,0 +1,13 @@
import torch


class DummyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.linear2(x)
        return x
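
The partial-load tests further down assert exact byte counts derived from this module's parameters. As a sanity check of that arithmetic, a short illustrative helper (calc_param_bytes is a hypothetical name, not a function used by the model cache):

import torch

from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule


def calc_param_bytes(module: torch.nn.Module) -> int:
    # Sum of numel * element_size over all parameters.
    return sum(p.numel() * p.element_size() for p in module.parameters())


# Two Linear(10, 10) layers, each with 10*10 weights + 10 biases, stored as float32 (4 bytes):
assert calc_param_bytes(DummyModule()) == (10 * 10 + 10) * 4 * 2  # 880 bytes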


@@ -0,0 +1,50 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
    CachedModelOnlyFullLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule

parameterize_mps_and_cuda = pytest.mark.parametrize(
    ("device"),
    [
        pytest.param(
            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
        ),
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
    ],
)


@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert cached_model.total_bytes() == 100


@parameterize_mps_and_cuda
def test_cached_model_is_in_vram(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert not cached_model.is_in_vram()

    cached_model.full_load_to_vram()
    assert cached_model.is_in_vram()

    cached_model.full_unload_from_vram()
    assert not cached_model.is_in_vram()


@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)

    assert cached_model.full_load_to_vram() == 100
    assert cached_model.is_in_vram()
    assert all(p.device.type == device for p in cached_model.model.parameters())

    assert cached_model.full_unload_from_vram() == 100
    assert not cached_model.is_in_vram()
    assert all(p.device.type == "cpu" for p in cached_model.model.parameters())
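
One branch of CachedModelOnlyFullLoad not covered above is a wrapped object without a .to() method. A hedged sketch of such a test, reusing this file's imports and the parameterize_mps_and_cuda marker (NonTorchModel is a made-up stand-in, not part of this commit):

@parameterize_mps_and_cuda
def test_cached_model_full_load_of_non_torch_model(device: str):
    class NonTorchModel:
        """A model object with no .to() method."""

    cached_model = CachedModelOnlyFullLoad(model=NonTorchModel(), compute_device=torch.device(device), total_bytes=100)

    # full_load_to_vram() should be a no-op that reports 0 bytes moved.
    assert cached_model.full_load_to_vram() == 0
    assert not cached_model.is_in_vram()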


@@ -0,0 +1,72 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule

parameterize_mps_and_cuda = pytest.mark.parametrize(
    ("device"),
    [
        pytest.param(
            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
        ),
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
    ],
)


@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available.")
    if device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available.")

    model = DummyModule()
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
    linear_numel = 10 * 10 + 10
    assert cached_model.total_bytes() == linear_numel * 4 * 2

    cached_model.model.to(dtype=torch.float16)
    assert cached_model.total_bytes() == linear_numel * 2 * 2


@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str):
    model = DummyModule()
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
    assert cached_model.cur_vram_bytes() == 0

    cached_model.model.to(device=torch.device(device))
    assert cached_model.cur_vram_bytes() == cached_model.total_bytes()


@parameterize_mps_and_cuda
def test_cached_model_partial_load(device: str):
    model = DummyModule()
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == 0

    target_vram_bytes = int(model_total_bytes * 0.6)
    loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
    assert loaded_bytes > 0
    assert loaded_bytes < model_total_bytes
    assert loaded_bytes == cached_model.cur_vram_bytes()


@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str):
    model = DummyModule()
    model.to(device=torch.device(device))
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == model_total_bytes

    bytes_to_free = int(model_total_bytes * 0.4)
    freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
    assert freed_bytes >= bytes_to_free
    assert freed_bytes < model_total_bytes
    assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
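
The CachedModelWithPartialLoad implementation itself is not part of this diff, only its tests. Purely as a mental model of the interface exercised above, here is a naive sketch that moves whole leaf submodules between the CPU and the compute device until a byte budget is met. The class name and the module-granularity strategy are illustrative assumptions; the real class may track bytes at finer granularity and handle shared weights, buffers, and non-module models.

import torch


class NaivePartialLoadSketch:
    """Illustrative only: approximates the CachedModelWithPartialLoad interface at module granularity."""

    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
        self._model = model
        self._compute_device = compute_device

    def _module_bytes(self, module: torch.nn.Module) -> int:
        # Bytes of the parameters owned directly by this module (not its children).
        return sum(p.numel() * p.element_size() for p in module.parameters(recurse=False))

    def _on_compute_device(self, module: torch.nn.Module) -> bool:
        return next(module.parameters(recurse=False)).device.type == self._compute_device.type

    def total_bytes(self) -> int:
        return sum(p.numel() * p.element_size() for p in self._model.parameters())

    def cur_vram_bytes(self) -> int:
        return sum(
            p.numel() * p.element_size()
            for p in self._model.parameters()
            if p.device.type == self._compute_device.type
        )

    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
        """Move submodules to the compute device without exceeding the byte budget."""
        loaded = 0
        for module in self._model.modules():
            module_bytes = self._module_bytes(module)
            if module_bytes == 0 or self._on_compute_device(module):
                continue
            if loaded + module_bytes > vram_bytes_to_load:
                continue  # the whole module would overshoot the budget; a finer-grained loader could do better
            module.to(self._compute_device)
            loaded += module_bytes
        return loaded

    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
        """Move submodules back to the CPU until at least vram_bytes_to_free bytes have been freed."""
        freed = 0
        for module in self._model.modules():
            if freed >= vram_bytes_to_free:
                break
            module_bytes = self._module_bytes(module)
            if module_bytes == 0 or not self._on_compute_device(module):
                continue
            module.to("cpu")
            freed += module_bytes
        return freed

With DummyModule (two 440-byte Linear layers), a 60% budget loads one layer (440 bytes) and a 40% unload request frees one layer, which satisfies the inequalities asserted in the tests above.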