WIP - messing around with some alternative autocast implementations

Ryan Dick 2024-12-11 02:58:56 +00:00
parent f6045682c0
commit f109914eb3
5 changed files with 210 additions and 5 deletions

View File

@@ -2,9 +2,7 @@ import itertools
 import torch
-from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
-    add_autocast_to_module_forward,
-)
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast_context import add_autocast_to_modules
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -35,7 +33,9 @@ class CachedModelWithPartialLoad:
         self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
         # Monkey-patch the model to add autocasting to the model's forward method.
-        add_autocast_to_module_forward(model, compute_device)
+        # add_autocast_to_module_forward(model, compute_device)
+        # inject_custom_layers_into_module(model)
+        add_autocast_to_modules(model, compute_device)
         self._total_bytes = sum(
             calc_tensor_size(p) for p in itertools.chain(self._model.parameters(), self._model.buffers())
View File

@@ -338,7 +338,8 @@ class ModelCache:
             )
             vram_bytes_freed += cache_entry_bytes_freed
-        TorchDevice.empty_cache()
+        if vram_bytes_freed > 0:
+            TorchDevice.empty_cache()
         return vram_bytes_freed
     # def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:

View File

@@ -0,0 +1,80 @@
from typing import Any, Iterator

import torch


def _add_autocast_to_module(m: torch.nn.Module, to_device: torch.device):
    def forward_pre_hook(module: torch.nn.Module, args: tuple[Any, ...]):
        # Backup shallow copies of the existing parameters and buffers.
        module._parameters_backup = {k: v for k, v in module._parameters.items()}
        module._buffers_backup = {k: v for k, v in module._buffers.items()}

        # Replace the parameters and buffers with their device-casted versions.
        for key, param in module._parameters.items():
            if param is not None and param.device.type != to_device.type:
                out_param = torch.nn.Parameter(param.to(to_device, copy=True), requires_grad=param.requires_grad)
                module._parameters[key] = out_param

        for key, buffer in module._buffers.items():
            if buffer is not None and buffer.device.type != to_device.type:
                out_buffer = buffer.to(to_device, copy=True)
                module._buffers[key] = out_buffer

    def forward_post_hook(module: torch.nn.Module, args: tuple[Any, ...], output: Any):
        # Restore the original parameters and buffers.
        if hasattr(module, "_parameters_backup"):
            module._parameters = module._parameters_backup
            del module._parameters_backup
        if hasattr(module, "_buffers_backup"):
            module._buffers = module._buffers_backup
            del module._buffers_backup

    m.register_forward_pre_hook(forward_pre_hook)
    m.register_forward_hook(forward_post_hook, always_call=True)
def _add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
    # Wrap the module's forward so that tensor args/kwargs are moved to to_device before the original forward runs.
    original_forward = m.forward

    def forward_with_cast(*args: Any, **kwargs: Any) -> Any:
        args = tuple(a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args)
        kwargs = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return original_forward(*args, **kwargs)

    m.forward = forward_with_cast
def _is_leaf_module(m: torch.nn.Module) -> bool:
    for _ in m.children():
        # If the m.children() generator returns a value, then m is not a leaf module.
        return False
    # If we get here, then the m.children() generator yielded nothing, so m is a leaf module.
    return True


def _named_leaf_modules(m: torch.nn.Module) -> Iterator[tuple[str, torch.nn.Module]]:
    """An iterator over all leaf modules in the module hierarchy."""
    for name, module in m.named_modules():
        if _is_leaf_module(module):
            yield name, module
def add_autocast_to_all_leaf_modules(m: torch.nn.Module, to_device: torch.device):
    for _name, module in _named_leaf_modules(m):
        _add_autocast_to_module(module, to_device)


def add_autocast_to_modules(m: torch.nn.Module, to_device: torch.device):
    for _name, module in m.named_modules():
        if isinstance(
            module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.Embedding)
        ):
            _add_autocast_to_module(module, to_device)
# def _cast_to_device_and_run(
#     func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
# ):
#     args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
#     kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
#     return func(*args_on_device, **kwargs_on_device)

# - Fastest option is if we know exactly which params need to be cast.
#   - i.e. patch at module level
# - Inheritance vs composition?
#   - Inheritance means that the module looks slightly closer to the original module in case other layers want to
#     patch it.
#   - Composition means that the module looks
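
The notes above weigh inheritance against composition for patching individual layers. The inheritance route is the one taken in the next file (reassigning module.__class__ to a custom subclass). For contrast, a minimal sketch of the composition route could look like the following; AutocastWrapper and its details are assumptions for illustration only, not part of this commit:

import torch


class AutocastWrapper(torch.nn.Module):
    # Hypothetical composition-based alternative: hold the original module and cast its
    # parameters/buffers to the compute device for the duration of each forward call.
    def __init__(self, wrapped: torch.nn.Module, to_device: torch.device):
        super().__init__()
        self.wrapped = wrapped
        self.to_device = to_device

    def forward(self, *args, **kwargs):
        # Move tensor inputs to the compute device.
        args = tuple(a.to(self.to_device) if isinstance(a, torch.Tensor) else a for a in args)
        # Build device-cast copies of the wrapped module's parameters and buffers without
        # mutating the originals, then run the wrapped forward functionally against them.
        # (torch.func.functional_call requires torch >= 2.0.)
        overrides = {
            name: t.to(self.to_device) if t.device.type != self.to_device.type else t
            for name, t in list(self.wrapped.named_parameters()) + list(self.wrapped.named_buffers())
        }
        return torch.func.functional_call(self.wrapped, overrides, args, kwargs)

The tradeoff flagged in the next file's comments applies here: an AutocastWrapper around a Linear is not an instance of torch.nn.Linear, which is one reason the notes lean toward the __class__-swap (inheritance) approach.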

View File

@@ -0,0 +1,76 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)

# Properties to preserve:
# - isinstance(m, torch.nn.Linear) should still work
# - patching the weights should still work if non-quantized


def cast_to_device(t: T, to_device: torch.device, non_blocking: bool = True) -> T:
    if t is None:
        return t
    if t.device.type != to_device.type:
        return t.to(to_device, non_blocking=non_blocking)
    return t


def inject_custom_layers_into_module(model: torch.nn.Module):
    def inject_custom_layers(module: torch.nn.Module):
        if isinstance(module, torch.nn.Linear):
            module.__class__ = CustomLinear
        elif isinstance(module, torch.nn.Conv1d):
            module.__class__ = CustomConv1d
        elif isinstance(module, torch.nn.Conv2d):
            module.__class__ = CustomConv2d
        elif isinstance(module, torch.nn.GroupNorm):
            module.__class__ = CustomGroupNorm
        elif isinstance(module, torch.nn.Embedding):
            module.__class__ = CustomEmbedding

    model.apply(inject_custom_layers)


class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


class CustomConv1d(torch.nn.Conv1d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomConv2d(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomGroupNorm(torch.nn.GroupNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


class CustomEmbedding(torch.nn.Embedding):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        return torch.nn.functional.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
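
A quick usage sketch for inject_custom_layers_into_module, illustrating the two properties listed at the top of the file; the small layer and the CUDA guard are assumptions for the example, not part of the commit:

import torch

linear = torch.nn.Linear(8, 4)  # parameters start on the CPU
inject_custom_layers_into_module(linear)

# The __class__ swap keeps isinstance checks against the original layer type working.
assert isinstance(linear, torch.nn.Linear)
assert type(linear).__name__ == "CustomLinear"

if torch.cuda.is_available():
    x = torch.randn(2, 8, device="cuda")
    y = linear(x)  # the weights are cast to CUDA only for this call
    assert y.device.type == "cuda"
    assert linear.weight.device.type == "cpu"  # stored weights remain on the CPU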

View File

@@ -0,0 +1,48 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast_context import (
    add_autocast_to_all_leaf_modules,
)
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule

# def test_torch_function_autocast_device_context():
#     if not torch.cuda.is_available():
#         pytest.skip("CUDA is not available.")
#
#     model = DummyModule()
#     # Model parameters should start off on the CPU.
#     assert all(p.device.type == "cpu" for p in model.parameters())
#
#     with TorchFunctionAutocastDeviceContext(to_device=torch.device("cuda")):
#         x = torch.randn(10, 10, device="cuda")
#         y = model(x)
#
#     # The model output should be on the GPU.
#     assert y.device.type == "cuda"
#
#     # The model parameters should still be on the CPU.
#     assert all(p.device.type == "cpu" for p in model.parameters())
def test_add_autocast_to_module_forward():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available.")

    model = DummyModule()
    assert all(p.device.type == "cpu" for p in model.parameters())

    add_autocast_to_all_leaf_modules(model, torch.device("cuda"))

    # After adding autocast, the model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    x = torch.randn(10, 10, device="cuda")
    y = model(x)

    # The model output should be on the GPU.
    assert y.device.type == "cuda"

    # The model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    # The autocast behavior should be scoped to the model's forward call.
    # So, attempting an operation with conflicting devices outside of the model should still raise an error.
    with pytest.raises(RuntimeError):
        _ = torch.randn(10, device="cuda") * torch.randn(10, device="cpu")
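
DummyModule is imported from tests/backend/model_manager/load/model_cache/dummy_module.py, which is not part of this diff. A plausible shape for it, assumed purely for illustration:

import torch


class DummyModule(torch.nn.Module):
    # Hypothetical stand-in for the dummy_module.py test helper (not shown in this diff):
    # a couple of leaf Linear layers so that add_autocast_to_all_leaf_modules has modules to hook.
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 32)
        self.linear2 = torch.nn.Linear(32, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        return self.linear2(x)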