Mirror of https://github.com/invoke-ai/InvokeAI.git
Add a CustomModuleMixin class with a flag for enabling/disabling autocasting (since it incurs some runtime speed overhead).
commit 7d6ab0ceb2
parent 9692a36dd6
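The flag trades a little speed for device flexibility: with autocasting enabled, each forward call casts the module's parameters to the input's device; with it disabled, the stock torch forward runs with no per-call device handling. A minimal usage sketch follows; apply_custom_layers_to_model() and set_device_autocasting_enabled() appear in the diff below, but the import location of apply_custom_layers_to_model is an assumption, not something this diff confirms.

```python
# Hedged usage sketch. The import path of apply_custom_layers_to_model is assumed
# from the package layout visible in this diff; it is not confirmed by the diff itself.
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
apply_custom_layers_to_model(model)  # swaps in the custom classes; autocasting is enabled by default

# If the whole model already lives on the execution device, the per-call casts are
# pure overhead, so the flag can be turned off for a small speedup.
for module in model.modules():
    if hasattr(module, "set_device_autocasting_enabled"):
        module.set_device_autocasting_enabled(False)
```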
@@ -1,10 +1,19 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomConv1d(torch.nn.Conv1d):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
         return self._conv_forward(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
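Each custom layer's _autocast_forward() leans on the cast_to_device() helper imported above. Its body is not shown in this diff; the sketch below is an assumption about what it does (return the tensor untouched when it is already on the target device, otherwise copy it), which is also where the per-call overhead that the new flag avoids comes from.

```python
# Assumed behavior of cast_to_device(); the real helper in
# torch_module_autocast/cast_to_device.py may differ (e.g. non-blocking copies).
from typing import Optional

import torch


def cast_to_device(t: Optional[torch.Tensor], device: torch.device) -> Optional[torch.Tensor]:
    if t is None:
        return None  # e.g. a Conv/Linear layer constructed with bias=False
    if t.device == device:
        return t  # already on the right device; no copy
    return t.to(device)  # copy the parameter to the input's device for this call
```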
@@ -1,10 +1,19 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomConv2d(torch.nn.Conv2d):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
         return self._conv_forward(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
@@ -1,10 +1,13 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomEmbedding(torch.nn.Embedding):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         return torch.nn.functional.embedding(
             input,
@@ -15,3 +18,9 @@ class CustomEmbedding(torch.nn.Embedding):
             self.scale_grad_by_freq,
             self.sparse,
         )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
@@ -1,10 +1,19 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomGroupNorm(torch.nn.GroupNorm):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
         return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
@@ -2,11 +2,14 @@ import bitsandbytes as bnb
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
 from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt


-class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
+    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         matmul_state = bnb.MatmulLtState()
         matmul_state.threshold = self.state.threshold
         matmul_state.has_fp16_weights = self.state.has_fp16_weights
@@ -25,3 +28,9 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
         # its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
         # on the wrong device.
         return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(x)
+        else:
+            return super().forward(x)
@@ -4,11 +4,14 @@ import bitsandbytes as bnb
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
 from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4


-class CustomInvokeLinearNF4(InvokeLinearNF4):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
+    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)

         # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -43,3 +46,9 @@ class CustomInvokeLinearNF4(InvokeLinearNF4):

         bias = cast_to_device(self.bias, x.device)
         return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(x)
+        else:
+            return super().forward(x)
@@ -1,10 +1,19 @@
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)


-class CustomLinear(torch.nn.Linear):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+class CustomLinear(torch.nn.Linear, CustomModuleMixin):
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
         return torch.nn.functional.linear(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
@@ -0,0 +1,11 @@
+class CustomModuleMixin:
+    """A mixin class for custom modules that enables device autocasting of module parameters."""
+
+    _device_autocasting_enabled = False
+
+    def set_device_autocasting_enabled(self, enabled: bool):
+        """Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to
+        disable autocasting, which results in slightly faster execution speed when we know that device autocasting is
+        not needed.
+        """
+        self._device_autocasting_enabled = enabled
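Because apply_custom_layers_to_model() installs these classes by reassigning module.__class__ on modules that have already been constructed (see the next hunk), no custom __init__ ever runs. _device_autocasting_enabled is therefore a class-level default that the swapped instance picks up through the MRO, and set_device_autocasting_enabled() shadows it with an instance attribute. A self-contained sketch of that pattern, using only the import paths shown above and a hypothetical ToyCustomLinear rather than any class from the repository:

```python
# Self-contained illustration of the mixin pattern used by every custom layer in
# this commit. ToyCustomLinear is hypothetical; only the two imports are real paths.
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)


class ToyCustomLinear(torch.nn.Linear, CustomModuleMixin):
    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        # Slow path: cast parameters to the input's device on every call.
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self._device_autocasting_enabled:
            return self._autocast_forward(input)
        # Fast path: stock torch forward, no per-call device checks or copies.
        return super().forward(input)


layer = torch.nn.Linear(4, 4)
layer.__class__ = ToyCustomLinear                   # same trick the tests in this diff use
assert layer._device_autocasting_enabled is False   # class-level default from the mixin
layer.set_device_autocasting_enabled(True)          # instance attribute shadows the default
_ = layer(torch.randn(2, 4))
```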
@@ -46,6 +46,8 @@ def apply_custom_layers_to_model(model: torch.nn.Module):
         override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
         if override_type is not None:
             module.__class__ = override_type
+            # TODO(ryand): In the future, we should manage this flag on a per-module basis.
+            module.set_device_autocasting_enabled(True)

     # model.apply(...) calls apply_custom_layers(...) on each module in the model.
     model.apply(apply_custom_layers)
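One way the per-module management mentioned in the TODO could look: enable autocasting only for modules whose own parameters are not already on the execution device, and keep the faster non-casting path everywhere else. This is an illustrative sketch, not the repository's implementation; only the CustomModuleMixin import path is taken from the diff.

```python
# Hedged sketch of per-module flag management (see the TODO above). Illustrative only.
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)


def set_autocasting_per_module(model: torch.nn.Module, target_device: torch.device) -> None:
    for module in model.modules():
        if isinstance(module, CustomModuleMixin):
            # Only pay the casting overhead where some parameter actually lives elsewhere.
            needs_cast = any(p.device != target_device for p in module.parameters(recurse=False))
            module.set_device_autocasting_enabled(needs_cast)
```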
@@ -215,7 +215,13 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
     custom_layer = copy.deepcopy(orig_layer)
     apply_custom_layers_to_model(custom_layer)

+    # Inference should still fail with autocasting disabled.
+    custom_layer.set_device_autocasting_enabled(False)
+    with pytest.raises(RuntimeError):
+        _ = custom_layer(x)
+
     # Run inference with the wrapped layer on the device.
+    custom_layer.set_device_autocasting_enabled(True)
     custom_output = custom_layer(x)
     assert custom_output.device.type == device
@@ -68,6 +68,7 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: I

     # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
     linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
+    linear_8bit_lt_layer.set_device_autocasting_enabled(True)
     y_custom = linear_8bit_lt_layer(x)

     # Assert that the quantized and custom layers produce the same output.
@@ -40,6 +40,7 @@ def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLi

     # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
     linear_nf4_layer.__class__ = CustomInvokeLinearNF4
+    linear_nf4_layer.set_device_autocasting_enabled(True)
     y_custom = linear_nf4_layer(x)

     # Assert that the quantized and custom layers produce the same output.
@@ -66,6 +67,7 @@ def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLin

     # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
     linear_nf4_layer.__class__ = CustomInvokeLinearNF4
+    linear_nf4_layer.set_device_autocasting_enabled(True)
     y_custom = linear_nf4_layer(x)

     # Assert that the state dict (and the tensors that it references) are still on the CPU.
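A natural companion check (not part of this diff) would exercise the disabled fast path: with everything on one device and autocasting off, the custom layer should fall back to the stock forward() and match the original module exactly. A hedged sketch, again assuming the import location of apply_custom_layers_to_model:

```python
# Hypothetical companion test; not part of this commit. The import path of
# apply_custom_layers_to_model is an assumption.
import copy

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)


def test_disabled_autocasting_matches_original_on_cpu():
    orig_layer = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8)

    custom_layer = copy.deepcopy(orig_layer)
    apply_custom_layers_to_model(custom_layer)
    custom_layer.set_device_autocasting_enabled(False)

    # With the flag off, forward() defers to torch.nn.Linear.forward(), so the
    # output should match the original layer's.
    assert torch.allclose(custom_layer(x), orig_layer(x))
```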