Add a CustomModuleMixin class with a flag for enabling/disabling autocasting (since autocasting incurs some runtime speed overhead).

This commit is contained in:
Ryan Dick 2024-12-26 20:08:30 +00:00
parent 9692a36dd6
commit 7d6ab0ceb2
12 changed files with 99 additions and 14 deletions
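For context, a minimal usage sketch of the new flag, not part of this commit. The import path for apply_custom_layers_to_model is an assumption based on the package layout shown in this diff; only the function name and the mixin API appear in the changes below.

# Hedged sketch: enable/disable device autocasting on a model (assumed import path).
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,  # assumed module path; only the function name appears in this diff
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).to("cpu")

# Swap in the custom module classes. As of this commit, this also enables device
# autocasting on every overridden module.
apply_custom_layers_to_model(model)

# With autocasting enabled, each custom forward() casts its parameters to the input's
# device on the fly, so CPU-resident weights can serve a CUDA input.
x = torch.randn(1, 8, device="cuda" if torch.cuda.is_available() else "cpu")
y = model(x)

# When the weights are known to already be on the right device, the per-call cast can
# be skipped to avoid the runtime overhead that motivated this flag.
for module in model.modules():
    if isinstance(module, CustomModuleMixin):
        module.set_device_autocasting_enabled(False)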

View File

@@ -1,10 +1,19 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomConv1d(torch.nn.Conv1d):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
         return self._conv_forward(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)

View File

@@ -1,10 +1,19 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomConv2d(torch.nn.Conv2d):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
         return self._conv_forward(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)

View File

@@ -1,10 +1,13 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomEmbedding(torch.nn.Embedding):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         return torch.nn.functional.embedding(
             input,
@@ -15,3 +18,9 @@ class CustomEmbedding(torch.nn.Embedding):
             self.scale_grad_by_freq,
             self.sparse,
         )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)

View File

@@ -1,10 +1,19 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomGroupNorm(torch.nn.GroupNorm):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
         return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)

View File

@@ -2,11 +2,14 @@ import bitsandbytes as bnb
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
 from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt


-class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
+    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         matmul_state = bnb.MatmulLtState()
         matmul_state.threshold = self.state.threshold
         matmul_state.has_fp16_weights = self.state.has_fp16_weights
@@ -25,3 +28,9 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
         # its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
         # on the wrong device.
         return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(x)
+        else:
+            return super().forward(x)

View File

@@ -4,11 +4,14 @@ import bitsandbytes as bnb
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
 from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4


-class CustomInvokeLinearNF4(InvokeLinearNF4):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
+    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)

         # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -43,3 +46,9 @@ class CustomInvokeLinearNF4(InvokeLinearNF4):
         bias = cast_to_device(self.bias, x.device)

         return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(x)
+        else:
+            return super().forward(x)

View File

@@ -1,10 +1,19 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomLinear(torch.nn.Linear):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomLinear(torch.nn.Linear, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
         return torch.nn.functional.linear(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)

View File

@@ -0,0 +1,11 @@
+class CustomModuleMixin:
+    """A mixin class for custom modules that enables device autocasting of module parameters."""
+
+    _device_autocasting_enabled = False
+
+    def set_device_autocasting_enabled(self, enabled: bool):
+        """Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to
+        disable autocasting, which results in slightly faster execution speed when we know that device autocasting is
+        not needed.
+        """
+        self._device_autocasting_enabled = enabled
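A hedged sketch of toggling this flag on a single layer, mirroring the __class__-swap pattern used by the tests further down in this commit. The CustomLinear import path is assumed from the file layout shown in this diff; the shapes are illustrative only.

# Hedged per-layer sketch (not part of this commit); requires a CUDA device.
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
    CustomLinear,  # assumed module path for the CustomLinear class shown earlier in this diff
)

layer = torch.nn.Linear(4, 4)   # parameters stay on the CPU
layer.__class__ = CustomLinear  # same wrapping trick the tests below use

x = torch.randn(1, 4, device="cuda")

# The flag defaults to False, so forward() would fall through to torch.nn.Linear.forward
# and raise a device-mismatch RuntimeError for a CUDA input against CPU weights.
layer.set_device_autocasting_enabled(True)
y = layer(x)                    # weight and bias are cast to x.device on this call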

View File

@@ -46,6 +46,8 @@ def apply_custom_layers_to_model(model: torch.nn.Module):
         override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
         if override_type is not None:
             module.__class__ = override_type
+            # TODO(ryand): In the future, we should manage this flag on a per-module basis.
+            module.set_device_autocasting_enabled(True)

     # model.apply(...) calls apply_custom_layers(...) on each module in the model.
     model.apply(apply_custom_layers)

View File

@@ -215,7 +215,13 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
     custom_layer = copy.deepcopy(orig_layer)
     apply_custom_layers_to_model(custom_layer)

+    # Inference should still fail with autocasting disabled.
+    custom_layer.set_device_autocasting_enabled(False)
+    with pytest.raises(RuntimeError):
+        _ = custom_layer(x)
+
     # Run inference with the wrapped layer on the device.
+    custom_layer.set_device_autocasting_enabled(True)
     custom_output = custom_layer(x)
     assert custom_output.device.type == device

View File

@@ -68,6 +68,7 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: I
     # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
     linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
+    linear_8bit_lt_layer.set_device_autocasting_enabled(True)
     y_custom = linear_8bit_lt_layer(x)

     # Assert that the quantized and custom layers produce the same output.

View File

@@ -40,6 +40,7 @@ def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLi
     # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
     linear_nf4_layer.__class__ = CustomInvokeLinearNF4
+    linear_nf4_layer.set_device_autocasting_enabled(True)
     y_custom = linear_nf4_layer(x)

     # Assert that the quantized and custom layers produce the same output.
@@ -66,6 +67,7 @@ def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLin
     # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
     linear_nf4_layer.__class__ = CustomInvokeLinearNF4
+    linear_nf4_layer.set_device_autocasting_enabled(True)
     y_custom = linear_nf4_layer(x)

     # Assert that the state dict (and the tensors that it references) are still on the CPU.