Mirror of https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
WIP - messing around with some alternative autocast implementations
This commit is contained in:
parent f6045682c0, commit f109914eb3
@@ -2,9 +2,7 @@ import itertools

 import torch

-from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
-    add_autocast_to_module_forward,
-)
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast_context import add_autocast_to_modules
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size

@@ -35,7 +33,9 @@ class CachedModelWithPartialLoad:
         self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()

         # Monkey-patch the model to add autocasting to the model's forward method.
-        add_autocast_to_module_forward(model, compute_device)
+        # add_autocast_to_module_forward(model, compute_device)
+        # inject_custom_layers_into_module(model)
+        add_autocast_to_modules(model, compute_device)

         self._total_bytes = sum(
             calc_tensor_size(p) for p in itertools.chain(self._model.parameters(), self._model.buffers())
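A minimal sketch of the intended call-site behavior (assuming a CUDA device and a toy two-layer model standing in for a real cached model; not taken from this diff): the weights stay resident on the CPU, and the per-module hooks cast them to the compute device only for the duration of each forward call.

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast_context import add_autocast_to_modules

# Toy stand-in for a model managed by CachedModelWithPartialLoad; weights start on the CPU.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
add_autocast_to_modules(model, torch.device("cuda"))

# Inputs live on the compute device; each patched layer temporarily casts its
# parameters to CUDA for its own forward call and restores the CPU copies afterwards.
x = torch.randn(2, 8, device="cuda")
y = model(x)
assert y.device.type == "cuda"
assert all(p.device.type == "cpu" for p in model.parameters())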
@@ -338,7 +338,8 @@ class ModelCache:
            )
            vram_bytes_freed += cache_entry_bytes_freed

-        TorchDevice.empty_cache()
+        if vram_bytes_freed > 0:
+            TorchDevice.empty_cache()
         return vram_bytes_freed

     # def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
@@ -0,0 +1,80 @@
from typing import Any, Iterator

import torch


def _add_autocast_to_module(m: torch.nn.Module, to_device: torch.device):
    def forward_pre_hook(module: torch.nn.Module, args: tuple[Any, ...]):
        # Backup shallow copies of the existing parameters and buffers.
        module._parameters_backup = {k: v for k, v in module._parameters.items()}
        module._buffers_backup = {k: v for k, v in module._buffers.items()}

        # Replace the parameters and buffers with their device-casted versions.
        for key, param in module._parameters.items():
            if param is not None and param.device.type != to_device.type:
                out_param = torch.nn.Parameter(param.to(to_device, copy=True), requires_grad=param.requires_grad)
                module._parameters[key] = out_param

        for key, buffer in module._buffers.items():
            if buffer is not None and buffer.device.type != to_device.type:
                out_buffer = buffer.to(to_device, copy=True)
                module._buffers[key] = out_buffer

    def forward_post_hook(module: torch.nn.Module, args: tuple[Any, ...], output: Any):
        # Restore the original parameters and buffers.
        if hasattr(module, "_parameters_backup"):
            module._parameters = module._parameters_backup
            del module._parameters_backup
        if hasattr(module, "_buffers_backup"):
            module._buffers = module._buffers_backup
            del module._buffers_backup

    m.register_forward_pre_hook(forward_pre_hook)
    m.register_forward_hook(forward_post_hook, always_call=True)

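A minimal sketch of how the hook pair above behaves on a single layer (assuming a CUDA device is available; illustrative only, not part of the commit):

import torch

linear = torch.nn.Linear(4, 4)  # weight and bias start on the CPU
_add_autocast_to_module(linear, torch.device("cuda"))

x = torch.randn(2, 4, device="cuda")
# forward_pre_hook casts the parameters to CUDA; forward_post_hook restores the CPU
# originals (always_call=True makes the restore run even if forward raises).
y = linear(x)

assert y.device.type == "cuda"
assert linear.weight.device.type == "cpu"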
def _add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
    m.forward = _cast_to_device_and_run(m.forward, to_device)


def _is_leaf_module(m: torch.nn.Module) -> bool:
    for _ in m.children():
        # If the m.children() generator returns a value, then m is not a leaf module.
        return False
    # If we get here, then the m.children() generator returned an empty generator, so m is a leaf module.
    return True


def _named_leaf_modules(m: torch.nn.Module) -> Iterator[tuple[str, torch.nn.Module]]:
    """An iterator over all leaf modules in the module hierarchy."""
    for name, module in m.named_modules():
        if _is_leaf_module(module):
            yield name, module


def add_autocast_to_all_leaf_modules(m: torch.nn.Module, to_device: torch.device):
    for name, module in _named_leaf_modules(m):
        _add_autocast_to_module(module, to_device)


def add_autocast_to_modules(m: torch.nn.Module, to_device: torch.device):
    for name, module in m.named_modules():
        if isinstance(
            module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.Embedding)
        ):
            _add_autocast_to_module(module, to_device)


# def _cast_to_device_and_run(
#     func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
# ):
#     args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
#     kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
#     return func(*args_on_device, **kwargs_on_device)

# - Fastest option is if we know exactly which params need to be cast.
#   - i.e. patch at module level
# - Inheritance vs composition?
#   - Inheritance means that the module looks slightly closer to the original module in case other layers want to
#     patch it.
#   - Composition means that the module looks
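The composition note above trails off; a hedged sketch of what that option could look like (a hypothetical wrapper, not something this commit implements) is a module that owns the original layer and performs the same temporary parameter swap around its forward call.

import torch


class AutocastWrapper(torch.nn.Module):
    """Hypothetical composition-based alternative to the hook/inheritance approaches."""

    def __init__(self, wrapped: torch.nn.Module, to_device: torch.device):
        super().__init__()
        self.wrapped = wrapped
        self.to_device = to_device

    def forward(self, *args, **kwargs):
        # Temporarily swap in device-casted parameters, mirroring the parameter handling
        # in the forward hooks above (buffers omitted for brevity).
        cpu_params = {k: v for k, v in self.wrapped._parameters.items()}
        try:
            for key, param in cpu_params.items():
                if param is not None and param.device.type != self.to_device.type:
                    self.wrapped._parameters[key] = torch.nn.Parameter(
                        param.to(self.to_device, copy=True), requires_grad=param.requires_grad
                    )
            return self.wrapped(*args, **kwargs)
        finally:
            # Restore the original (e.g. CPU-resident) parameters.
            self.wrapped._parameters = cpu_params


# Usage sketch:
# layer = AutocastWrapper(torch.nn.Linear(4, 4), torch.device("cuda"))
# y = layer(torch.randn(2, 4, device="cuda"))

A wrapper like this leaves the original module untouched, but isinstance(layer, torch.nn.Linear) no longer holds for the wrapper itself, which is presumably why the next file lists isinstance preservation as a property to keep.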
@@ -0,0 +1,76 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)

# Properties to preserve:
# - isinstance(m, torch.nn.Linear) should still work
# - patching the weights should still work if non-quantized


def cast_to_device(t: T, to_device: torch.device, non_blocking: bool = True) -> T:
    if t is None:
        return t

    if t.device.type != to_device.type:
        return t.to(to_device, non_blocking=non_blocking)
    return t


def inject_custom_layers_into_module(model: torch.nn.Module):
    def inject_custom_layers(module: torch.nn.Module):
        if isinstance(module, torch.nn.Linear):
            module.__class__ = CustomLinear
        elif isinstance(module, torch.nn.Conv1d):
            module.__class__ = CustomConv1d
        elif isinstance(module, torch.nn.Conv2d):
            module.__class__ = CustomConv2d
        elif isinstance(module, torch.nn.GroupNorm):
            module.__class__ = CustomGroupNorm
        elif isinstance(module, torch.nn.Embedding):
            module.__class__ = CustomEmbedding

    model.apply(inject_custom_layers)


class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


class CustomConv1d(torch.nn.Conv1d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomConv2d(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomGroupNorm(torch.nn.GroupNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


class CustomEmbedding(torch.nn.Embedding):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        return torch.nn.functional.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
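A brief sketch of the "properties to preserve" comment in this file (assuming a CUDA device and that the helpers defined above are in scope; not part of the commit): after inject_custom_layers_into_module swaps each layer's __class__, isinstance checks against the original torch.nn types still pass, and forward casts the CPU weights to the input's device on the fly.

import torch

# Toy module using the layer types handled above; weights start on the CPU.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.GroupNorm(2, 4))
inject_custom_layers_into_module(model)

# isinstance against the original torch.nn types is preserved by the subclass swap.
assert isinstance(model[0], torch.nn.Linear)
assert type(model[0]) is CustomLinear
assert isinstance(model[1], torch.nn.GroupNorm)

# A CUDA input is handled by casting the CPU weights inside each forward call.
x = torch.randn(2, 8, device="cuda")
y = model(x)
assert y.device.type == "cuda"
assert model[0].weight.device.type == "cpu"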
@@ -0,0 +1,48 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast_context import (
    add_autocast_to_all_leaf_modules,
)
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule

# def test_torch_function_autocast_device_context():
#     if not torch.cuda.is_available():
#         pytest.skip("CUDA is not available.")

#     model = DummyModule()
#     # Model parameters should start off on the CPU.
#     assert all(p.device.type == "cpu" for p in model.parameters())

#     with TorchFunctionAutocastDeviceContext(to_device=torch.device("cuda")):
#         x = torch.randn(10, 10, device="cuda")
#         y = model(x)

#     # The model output should be on the GPU.
#     assert y.device.type == "cuda"

#     # The model parameters should still be on the CPU.
#     assert all(p.device.type == "cpu" for p in model.parameters())


def test_add_autocast_to_module_forward():
    model = DummyModule()
    assert all(p.device.type == "cpu" for p in model.parameters())

    add_autocast_to_all_leaf_modules(model, torch.device("cuda"))
    # After adding autocast, the model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    x = torch.randn(10, 10, device="cuda")
    y = model(x)

    # The model output should be on the GPU.
    assert y.device.type == "cuda"

    # The model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    # The autocast context should automatically be disabled after the model forward call completes.
    # So, attempting to perform an operation with conflicting devices should raise an error.
    with pytest.raises(RuntimeError):
        _ = torch.randn(10, device="cuda") * torch.randn(10, device="cpu")
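The test imports DummyModule from tests.backend.model_manager.load.model_cache.dummy_module, which is not shown in this diff. A hypothetical stand-in that is consistent with how the test uses it (CPU-resident parameters, accepts a (N, 10) input) might look like the following sketch; the repository's actual DummyModule may differ.

import torch


class DummyModule(torch.nn.Module):
    """Hypothetical sketch of the test fixture, not the repository's definition."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 32)
        self.linear2 = torch.nn.Linear(32, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        return self.linear2(x)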