Partial Loading PR2: Add utils to support partial loading of models from CPU to GPU (#7494)

## Summary

This PR adds utilities to support partial loading of models from CPU to
GPU. The new utilities are not yet being used by the ModelCache, so
there should be no functional behavior changes in this PR.

Detailed changes:

- Add autocast modules that are designed to wrap common
`torch.nn.Module`s and enable them to run with automatic device casting.
E.g. a linear layer on the CPU can be executed with an input tensor on
the GPU by streaming the weights to the GPU at runtime.
- Add unit tests for the aforementioned autocast modules to verify that
they work for all supported quantization formats (GGUF, BnB NF4, BnB
LLM.int8()).
- Add `CachedModelWithPartialLoad` and `CachedModelOnlyFullLoad` classes
to manage partial loading at the model level (see the usage sketch below).
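
Roughly, the new utilities are intended to be used as sketched below. This is illustrative only: `model` and `x` are placeholders, the VRAM budgets are arbitrary, and the actual wiring into the ModelCache will come in a follow-up PR.

```python
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)

# `model` is any torch.nn.Module that currently lives on the CPU.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))

# Load as many weights as fit within a VRAM budget (in bytes). Layers whose weights stay on the
# CPU are patched with autocast wrappers that stream their weights to the GPU at call time.
cached_model.partial_load_to_vram(4 * 2**30)

# Inference works with a GPU input even though only part of the model is resident in VRAM.
y = cached_model.model(x.to("cuda"))

# Later, free roughly 1 GB of VRAM, e.g. to make room for another model.
cached_model.partial_unload_from_vram(2**30)
```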

## Alternative Implementations

Several options were explored for supporting inference on
partially-loaded models. The pros/cons of the explored options are
summarized here for reference. In the end, wrapper modules were selected
as the best overall solution for our use case.

Option 1: Re-implement the .forward() methods of modules to add support
for device conversions
- This is the option implemented in this PR.
- This approach is the most manual of the three, but as a result offers
the broadest compatibility with unusual model types. It is manual in
that we have to explicitly add support for all module types that we wish
to support. Fortunately, the list of foundational module types is
relatively small (e.g. the current set of implemented layers covers all
but 0.04 MB of the full FLUX model).

Option 2: Implement a custom Tensor type that casts tensors to a
`target_device` each time the tensor is used
- This approach has the nice property that it is injected at the tensor
level, and the model does not need to be modified in any way.
- One challenge with this approach is handling interactions with other
custom tensor types (e.g. GGMLTensor). This problem is solvable, but
definitely introduces a layer of complexity. (There are likely to also
be some similar issues with interactions with the BnB quantization, but
I didn't get as far as testing BnB.)
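
For reference, a minimal sketch of this tensor-subclass idea (not what this PR implements) might look like the following, assuming PyTorch 2.x where `torch._C.DisableTorchFunctionSubclass` prevents the cast itself from re-entering the handler. The interactions with `GGMLTensor`/BnB mentioned above are exactly where the extra complexity would show up.

```python
import torch


class DeviceAutocastTensor(torch.Tensor):
    """Sketch: a tensor that hops to a target device whenever it participates in an op."""

    target_device = torch.device("cuda")

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def cast(a: object) -> object:
            return a.to(cls.target_device) if isinstance(a, torch.Tensor) else a

        # Disable subclass dispatch while casting so that .to() does not recurse back here.
        # (Nested containers of tensors are not handled in this sketch.)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*(cast(a) for a in args), **{k: cast(v) for k, v in kwargs.items()})


# A CPU weight wrapped in this type would be moved to the GPU the first time it is used:
#   weight = torch.randn(64, 32).as_subclass(DeviceAutocastTensor)
```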

Option 3: Override the `__torch_function__` dispatch calls globally and
cast all params to the execution device.
- This approach is nice and simple: just apply a global context manager
and all operations will happen on the compute device regardless of the
device of the participating tensors.
- Challenges:
  - Overriding the `__torch_function__` dispatch calls introduces some
overhead even if the tensors are already on the correct device.
  - It is difficult to manage the autocasting context manager. E.g. it is
tempting to apply it to the model's `.forward(...)` method, but we use
some models with non-standard entrypoints. And we don't want to end up
with nested autocasting context managers.
  - BnB applies quantization side effects when a param is moved to the GPU;
this interacts in unexpected ways with a global context manager.
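
For reference, a minimal sketch of this global-override idea (again, not what this PR implements), using PyTorch 2.x's `torch.overrides.TorchFunctionMode`:

```python
import torch
from torch.overrides import TorchFunctionMode


class CastToExecutionDevice(TorchFunctionMode):
    """Sketch: move every tensor argument to the execution device before each op runs."""

    def __init__(self, device: torch.device):
        super().__init__()
        self._device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def cast(a: object) -> object:
            return a.to(self._device) if isinstance(a, torch.Tensor) else a

        return func(*(cast(a) for a in args), **{k: cast(v) for k, v in kwargs.items()})


# Everything inside the context runs on the compute device, regardless of where the
# participating tensors live. The overhead, nesting, and BnB issues listed above all
# stem from this being applied globally:
#   with CastToExecutionDevice(torch.device("cuda")):
#       y = model(x)
```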


## QA Instructions

Most of the changes in this PR should not impact active code, and thus
should not cause any changes to behavior. The main risks come from
bumping the bitsandbytes dependency and some minor modifications to the
bitsandbytes quantization code.

- [x] Regression test bitsandbytes NF4 quantization
- [x] Regression test bitsandbytes LLM.int8() quantization
- [x] Regression test on MacOS (to ensure that there are no lingering
bitsandbytes import errors)

I also tested the new utilities for inference on full models in another
branch to validate that there were no major issues. This functionality
will be tested more thoroughly in a future PR.

## Merge Plan

- [x] #7492 should be merged first so that the target branch can be
updated to main.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
Commit 6bf5b747ce by Ryan Dick, 2024-12-27 09:20:24 -05:00 (committed by GitHub)
16 changed files with 1302 additions and 8 deletions


@@ -0,0 +1,93 @@
from typing import Any
import torch
class CachedModelOnlyFullLoad:
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
"""Initialize a CachedModelOnlyFullLoad.
Args:
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
compute_device (torch.device): The compute device to move the model to.
total_bytes (int): The total size (in bytes) of all the weights in the model.
"""
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
self._model = model
self._compute_device = compute_device
self._offload_device = torch.device("cpu")
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] | None = None
if isinstance(model, torch.nn.Module):
self._cpu_state_dict = model.state_dict()
self._total_bytes = total_bytes
self._is_in_vram = False
@property
def model(self) -> torch.nn.Module:
return self._model
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better.
return self._cpu_state_dict
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return self._total_bytes
def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
if self._is_in_vram:
return self._total_bytes
else:
return 0
def is_in_vram(self) -> bool:
"""Return true if the model is currently in VRAM."""
return self._is_in_vram
def full_load_to_vram(self) -> int:
"""Load all weights into VRAM (if supported by the model).
Returns:
The number of bytes loaded into VRAM.
"""
if self._is_in_vram:
# Already in VRAM.
return 0
if not hasattr(self._model, "to"):
# Model doesn't support moving to a device.
return 0
if self._cpu_state_dict is not None:
new_state_dict: dict[str, torch.Tensor] = {}
for k, v in self._cpu_state_dict.items():
new_state_dict[k] = v.to(self._compute_device, copy=True)
self._model.load_state_dict(new_state_dict, assign=True)
self._model.to(self._compute_device)
self._is_in_vram = True
return self._total_bytes
def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM.
Returns:
The number of bytes unloaded from VRAM.
"""
if not self._is_in_vram:
# Already in RAM.
return 0
if self._cpu_state_dict is not None:
self._model.load_state_dict(self._cpu_state_dict, assign=True)
self._model.to(self._offload_device)
self._is_in_vram = False
return self._total_bytes
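
A usage sketch for this class (not part of the file above; it mirrors the unit tests added later in this PR, with `model` and `total_bytes` as placeholders):

```python
import torch

# Wrap a CPU-resident model along with its precomputed size in bytes.
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device("cuda"), total_bytes=total_bytes)

# Move the whole model to the compute device. This is a no-op (returns 0) for models
# that do not support .to(), e.g. models that are not torch.nn.Modules.
cached_model.full_load_to_vram()
assert cached_model.is_in_vram()

# ... run inference via cached_model.model ...

# Move the whole model back to the CPU.
cached_model.full_unload_from_vram()
```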


@@ -0,0 +1,201 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
AUTOCAST_MODULE_TYPE_MAPPING,
apply_custom_layers_to_model,
remove_custom_layers_from_model,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger
def set_nested_attr(obj: object, attr: str, value: object):
"""A helper function that extends setattr() to support nested attributes.
Example:
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
"""
attrs = attr.split(".")
for attr in attrs[:-1]:
obj = getattr(obj, attr)
setattr(obj, attrs[-1], value)
class CachedModelWithPartialLoad:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
self._model = model
self._compute_device = compute_device
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
# TODO(ryand): Handle the case where the model size changes after initial load (e.g. due to dtype casting).
# Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
self._cur_vram_bytes: int | None = None
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
"""Find all modules that support autocasting."""
return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
keys_in_modules_that_do_not_support_autocast = set()
for key in self._cpu_state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
break
else:
keys_in_modules_that_do_not_support_autocast.add(key)
return keys_in_modules_that_do_not_support_autocast
def _move_non_persistent_buffers_to_device(self, device: torch.device):
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
so we need to move them manually.
"""
# HACK(ryand): Typically, non-persistent buffers are moved when calling module.to(device). We don't move entire
# modules, because we manage the devices of individual tensors using the state dict. Since non-persistent
# buffers are not included in the state dict, we need to handle them manually. The only way to do this is by
# using private torch.nn.Module attributes.
for module in self._model.modules():
for name, buffer in module.named_buffers():
if name in module._non_persistent_buffers_set:
module._buffers[name] = buffer.to(device, copy=True)
@property
def model(self) -> torch.nn.Module:
return self._model
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better.
return self._cpu_state_dict
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return self._total_bytes
def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
if self._cur_vram_bytes is None:
cur_state_dict = self._model.state_dict()
self._cur_vram_bytes = sum(
calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes
def full_load_to_vram(self) -> int:
"""Load all weights into VRAM."""
return self.partial_load_to_vram(self.total_bytes())
def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
Returns:
The number of bytes loaded into VRAM.
"""
# TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very
# least, we should reset self._cur_vram_bytes to None.
vram_bytes_loaded = 0
cur_state_dict = self._model.state_dict()
# First, process the keys that *must* be loaded into VRAM.
for key in self._keys_in_modules_that_do_not_support_autocast:
param = cur_state_dict[key]
if param.device.type == self._compute_device.type:
continue
param_size = calc_tensor_size(param)
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size
if vram_bytes_loaded > vram_bytes_to_load:
logger = InvokeAILogger.get_logger()
logger.warning(
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
"requested. This is the minimum set of weights in VRAM required to run the model."
)
# Next, process the keys that can optionally be loaded into VRAM.
fully_loaded = True
for key, param in cur_state_dict.items():
if param.device.type == self._compute_device.type:
continue
param_size = calc_tensor_size(param)
if vram_bytes_loaded + param_size > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
fully_loaded = False
continue
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size
if vram_bytes_loaded > 0:
# We load the entire state dict, not just the parameters that changed, in case there are modules that
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
self._model.load_state_dict(cur_state_dict, assign=True)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
if fully_loaded:
remove_custom_layers_from_model(self._model)
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
else:
apply_custom_layers_to_model(self._model)
# Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
# the vram_bytes_loaded tracking.
self._move_non_persistent_buffers_to_device(self._compute_device)
return vram_bytes_loaded
@torch.no_grad()
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
Returns:
The number of bytes unloaded from VRAM.
"""
vram_bytes_freed = 0
offload_device = "cpu"
cur_state_dict = self._model.state_dict()
for key, param in cur_state_dict.items():
if vram_bytes_freed >= vram_bytes_to_free:
break
if param.device.type == offload_device:
continue
cur_state_dict[key] = self._cpu_state_dict[key]
vram_bytes_freed += calc_tensor_size(param)
if vram_bytes_freed > 0:
self._model.load_state_dict(cur_state_dict, assign=True)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed
# We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
# layers.
apply_custom_layers_to_model(self._model)
return vram_bytes_freed


@@ -0,0 +1,50 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that it is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OriginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.
class CustomLinear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)
class CustomConv1d(torch.nn.Conv1d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)
class CustomConv2d(torch.nn.Conv2d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)
class CustomGroupNorm(torch.nn.GroupNorm):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
class CustomEmbedding(torch.nn.Embedding):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
return torch.nn.functional.embedding(
input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)


@@ -0,0 +1,15 @@
from typing import TypeVar
import torch
T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
def cast_to_device(t: T, to_device: torch.device) -> T:
"""Helper function to cast an optional tensor to a target device."""
if t is None:
return t
if t.device.type != to_device.type:
return t.to(to_device)
return t


@@ -0,0 +1,27 @@
import bitsandbytes as bnb
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
def forward(self, x: torch.Tensor) -> torch.Tensor:
matmul_state = bnb.MatmulLtState()
matmul_state.threshold = self.state.threshold
matmul_state.has_fp16_weights = self.state.has_fp16_weights
matmul_state.use_pool = self.state.use_pool
matmul_state.is_training = self.training
# The underlying InvokeInt8Params weight must already be quantized.
assert self.weight.CB is not None
matmul_state.CB = cast_to_device(self.weight.CB, x.device)
matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)
# weights are cast automatically as Int8Params, but the bias has to be cast manually.
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
# NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
# its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
# on the wrong device.
return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)


@@ -0,0 +1,45 @@
import copy
import bitsandbytes as bnb
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
class CustomInvokeLinearNF4(InvokeLinearNF4):
def forward(self, x: torch.Tensor) -> torch.Tensor:
bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if not self.compute_type_is_set:
self.set_compute_type(x)
self.compute_type_is_set = True
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
# HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it
# does not follow the tensor semantics of returning a new copy when converting to a different device). This
# means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To
# avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing
# this properly would require more invasive changes to the bitsandbytes library.
# Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting
# to a new device.
old_quant_state = copy.copy(self.weight.quant_state)
weight = cast_to_device(self.weight, x.device)
self.weight.quant_state = old_quant_state
# For some reason, the quant_state.to(...) implementation fails to cast the quant_state.code field. We do this
# manually here.
weight.quant_state.code = cast_to_device(weight.quant_state.code, x.device)
bias = cast_to_device(self.bias, x.device)
return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)


@@ -0,0 +1,56 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
CustomConv1d,
CustomConv2d,
CustomEmbedding,
CustomGroupNorm,
CustomLinear,
)
AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
torch.nn.Linear: CustomLinear,
torch.nn.Conv1d: CustomConv1d,
torch.nn.Conv2d: CustomConv2d,
torch.nn.GroupNorm: CustomGroupNorm,
torch.nn.Embedding: CustomEmbedding,
}
try:
# These dependencies are not expected to be present on MacOS.
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
CustomInvokeLinear8bitLt,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
CustomInvokeLinearNF4,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinear8bitLt] = CustomInvokeLinear8bitLt
AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinearNF4] = CustomInvokeLinearNF4
except ImportError:
pass
def apply_custom_layers_to_model(model: torch.nn.Module):
def apply_custom_layers(module: torch.nn.Module):
override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
if override_type is not None:
module.__class__ = override_type
# model.apply(...) calls apply_custom_layers(...) on each module in the model.
model.apply(apply_custom_layers)
def remove_custom_layers_from_model(model: torch.nn.Module):
# Invert AUTOCAST_MODULE_TYPE_MAPPING.
original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
def remove_custom_layers(module: torch.nn.Module):
override_type = original_module_type_mapping.get(type(module), None)
if override_type is not None:
module.__class__ = override_type
# model.apply(...) calls remove_custom_layers(...) on each module in the model.
model.apply(remove_custom_layers)
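
A usage sketch for these two helpers (not part of the file above; it mirrors the unit tests added later in this PR, with `model` and `x` as placeholders):

```python
# Swap supported layers to their Custom* subclasses in-place.
apply_custom_layers_to_model(model)

# The weights can stay on the CPU; they are streamed to the input's device per call.
y = model(x.to("cuda"))

# Restore the original torch.nn.Module classes (e.g. once the model is fully loaded).
remove_custom_layers_from_model(model)
```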


@@ -25,12 +25,9 @@ class InvokeInt8Params(bnb.nn.Int8Params):
             self.CB = self.data
             self.SCB = self.SCB.cuda()
         else:
-            # we store the 8-bit rows-major weight
-            # we convert this weight to the turing/ampere weight during the first inference pass
+            # We quantize the weight and store in 8bit row-major
             B = self.data.contiguous().half().cuda(device)
-            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
-            del CBt
-            del SCBt
+            CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
             self.data = CB
             self.CB = CB
             self.SCB = SCB
@@ -55,9 +52,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
         # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
         scb = state_dict.pop(prefix + "SCB", None)
-        # Currently, we only support weight_format=0.
         weight_format = state_dict.pop(prefix + "weight_format", None)
-        assert weight_format == 0
+        if weight_format is not None:
+            # Currently, we only support weight_format=0.
+            assert weight_format == 0

         # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
         # rather than raising an exception to correctly implement this API.
@@ -99,6 +97,27 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
             new_state.use_pool = self.state.use_pool

             self.state = new_state

+    def forward(self, x: torch.Tensor):
+        # The state management in the base bnb.nn.Linear8bitLt is very convoluted. We override the forward method to
+        # try to simplify the state management a bit. We initialize a new MatmulLtState object for each forward pass.
+        # By avoiding persistent state, it is easier to move the layer between devices without worrying about keeping
+        # references to weights on the old device (e.g. self.state.CB).
+        matmul_state = bnb.MatmulLtState()
+        matmul_state.threshold = self.state.threshold
+        matmul_state.has_fp16_weights = self.state.has_fp16_weights
+        matmul_state.use_pool = self.state.use_pool
+        matmul_state.is_training = self.training
+        # The underlying InvokeInt8Params weight must already be quantized.
+        assert self.weight.CB is not None
+        matmul_state.CB = self.weight.CB
+        matmul_state.SCB = self.weight.SCB
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually.
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        return bnb.matmul(x, self.weight, bias=self.bias, state=matmul_state)
+
+
 def _convert_linear_layers_to_llm_8bit(
     module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""


@@ -34,7 +34,7 @@ classifiers = [
 dependencies = [
     # Core generation dependencies, pinned for reproducible builds.
     "accelerate==1.0.1",
-    "bitsandbytes==0.43.3; sys_platform!='darwin'",
+    "bitsandbytes==0.45.0; sys_platform!='darwin'",
     "clip_anytorch==2.6.0",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel==2.0.2",
     "controlnet-aux==0.0.7",


@@ -0,0 +1,122 @@
import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
class NonTorchModel:
"""A model that does not sub-class torch.nn.Module."""
def __init__(self):
self.linear = torch.nn.Linear(10, 32)
def run_inference(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
assert cached_model.total_bytes() == 100
@parameterize_mps_and_cuda
def test_cached_model_is_in_vram(device: str):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
assert not cached_model.is_in_vram()
assert cached_model.cur_vram_bytes() == 0
cached_model.full_load_to_vram()
assert cached_model.is_in_vram()
assert cached_model.cur_vram_bytes() == 100
cached_model.full_unload_from_vram()
assert not cached_model.is_in_vram()
assert cached_model.cur_vram_bytes() == 0
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
assert cached_model.full_load_to_vram() == 100
assert cached_model.is_in_vram()
assert all(p.device.type == device for p in cached_model.model.parameters())
assert cached_model.full_unload_from_vram() == 100
assert not cached_model.is_in_vram()
assert all(p.device.type == "cpu" for p in cached_model.model.parameters())
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
assert not cached_model.is_in_vram()
# The CPU state dict can be accessed and has the expected properties.
cpu_state_dict = cached_model.get_cpu_state_dict()
assert cpu_state_dict is not None
assert len(cpu_state_dict) == len(model.state_dict())
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
# Full load the model into VRAM.
cached_model.full_load_to_vram()
assert cached_model.is_in_vram()
# The CPU state dict is still available, and still on the CPU.
cpu_state_dict = cached_model.get_cpu_state_dict()
assert cpu_state_dict is not None
assert len(cpu_state_dict) == len(model.state_dict())
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_inference(device: str):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
assert not cached_model.is_in_vram()
# Run inference on the CPU.
x = torch.randn(1, 10)
output1 = model(x)
assert output1.device.type == "cpu"
# Full load the model into VRAM.
cached_model.full_load_to_vram()
assert cached_model.is_in_vram()
# Run inference on the GPU.
output2 = model(x.to(device))
assert output2.device.type == device
# The outputs should be the same for both runs.
assert torch.allclose(output1, output2.to("cpu"))
@parameterize_mps_and_cuda
def test_non_torch_model(device: str):
model = NonTorchModel()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
assert not cached_model.is_in_vram()
# The model does not have a CPU state dict.
assert cached_model.get_cpu_state_dict() is None
# Attempting to load the model into VRAM should have no effect.
cached_model.full_load_to_vram()
assert not cached_model.is_in_vram()
assert cached_model.cur_vram_bytes() == 0
# Attempting to unload the model from VRAM should have no effect.
cached_model.full_unload_from_vram()
assert not cached_model.is_in_vram()
assert cached_model.cur_vram_bytes() == 0
# Running inference on the CPU should work.
output1 = model.run_inference(torch.randn(1, 10))
assert output1.device.type == "cpu"


@@ -0,0 +1,274 @@
import itertools
import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
linear1_numel = 10 * 32 + 32
linear2_numel = 32 * 64 + 64
buffer1_numel = 64
# Note that the non-persistent buffer (buffer2) is not included in .total_bytes() calculation.
assert cached_model.total_bytes() == (linear1_numel + linear2_numel + buffer1_numel) * 4
@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str):
model = DummyModule()
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
cached_model.full_load_to_vram()
assert cached_model.cur_vram_bytes() > 0
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
assert all(p.device.type == device for p in model.parameters())
assert all(p.device.type == device for p in model.buffers())
@parameterize_mps_and_cuda
def test_cached_model_partial_load(device: str):
model = DummyModule()
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
# Check that the model is partially loaded into VRAM.
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert loaded_bytes == sum(
calc_tensor_size(p)
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
if p.device.type == device and n != "buffer2"
)
# Check that the model's modules have been patched with CustomLinear layers.
assert type(model.linear1) is CustomLinear
assert type(model.linear2) is CustomLinear
@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str):
model = DummyModule()
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
cached_model.full_load_to_vram()
assert cached_model.cur_vram_bytes() == model_total_bytes
# Partially unload the model from VRAM.
bytes_to_free = int(model_total_bytes * 0.4)
freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
# Check that the model is partially unloaded from VRAM.
assert freed_bytes >= bytes_to_free
assert freed_bytes < model_total_bytes
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
assert freed_bytes == sum(
calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
)
# Check that the model's modules are still patched with CustomLinear layers.
assert type(model.linear1) is CustomLinear
assert type(model.linear2) is CustomLinear
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
loaded_bytes = cached_model.full_load_to_vram()
assert loaded_bytes > 0
assert loaded_bytes == model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
assert type(model.linear1) is torch.nn.Linear
assert type(model.linear2) is torch.nn.Linear
# Full unload the model from VRAM.
unloaded_bytes = cached_model.full_unload_from_vram()
# Check that the model is fully unloaded from VRAM.
assert unloaded_bytes > 0
assert unloaded_bytes == model_total_bytes
assert cached_model.cur_vram_bytes() == 0
# Note that the non-persistent buffer (buffer2) is not required to be unloaded from VRAM.
assert all(
p.device.type == "cpu"
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
if n != "buffer2"
)
@parameterize_mps_and_cuda
def test_cached_model_full_load_from_partial(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert type(model.linear1) is CustomLinear
assert type(model.linear2) is CustomLinear
# Full load the rest of the model into VRAM.
loaded_bytes_2 = cached_model.full_load_to_vram()
assert loaded_bytes_2 > 0
assert loaded_bytes_2 < model_total_bytes
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
assert type(model.linear1) is torch.nn.Linear
assert type(model.linear2) is torch.nn.Linear
@parameterize_mps_and_cuda
def test_cached_model_full_unload_from_partial(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
# Full unload the model from VRAM.
unloaded_bytes = cached_model.full_unload_from_vram()
assert unloaded_bytes > 0
assert unloaded_bytes == loaded_bytes
assert cached_model.cur_vram_bytes() == 0
# Note that the non-persistent buffer (buffer2) is not required to be unloaded from VRAM.
assert all(
p.device.type == "cpu"
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
if n != "buffer2"
)
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
assert cached_model.cur_vram_bytes() == 0
# The CPU state dict can be accessed and has the expected properties.
cpu_state_dict = cached_model.get_cpu_state_dict()
assert cpu_state_dict is not None
assert len(cpu_state_dict) == len(model.state_dict())
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
# Full load the model into VRAM.
cached_model.full_load_to_vram()
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
# The CPU state dict is still available, and still on the CPU.
cpu_state_dict = cached_model.get_cpu_state_dict()
assert cpu_state_dict is not None
assert len(cpu_state_dict) == len(model.state_dict())
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_inference(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Run inference on the CPU.
x = torch.randn(1, 10)
output1 = model(x)
assert output1.device.type == "cpu"
# Full load the model into VRAM.
loaded_bytes = cached_model.full_load_to_vram()
assert loaded_bytes > 0
assert loaded_bytes == model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
# Run inference on the GPU.
output2 = model(x.to(device))
assert output2.device.type == device
# The outputs should be the same for both runs.
assert torch.allclose(output1, output2.to("cpu"))
@parameterize_mps_and_cuda
def test_cached_model_partial_load_and_inference(device: str):
model = DummyModule()
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Run inference on the CPU.
x = torch.randn(1, 10)
output1 = model(x)
assert output1.device.type == "cpu"
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
# Check that the model is partially loaded into VRAM.
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert loaded_bytes == sum(
calc_tensor_size(p)
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
if p.device.type == device and n != "buffer2"
)
# Check that the model's modules have been patched with CustomLinear layers.
assert type(model.linear1) is CustomLinear
assert type(model.linear2) is CustomLinear
# Run inference on the GPU.
output2 = model(x.to(device))
assert output2.device.type == device
# The output should be the same as the output from the CPU.
assert torch.allclose(output1, output2.to("cpu"))


@@ -0,0 +1,31 @@
import pytest
import torch
class DummyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 32)
self.linear2 = torch.nn.Linear(32, 64)
self.register_buffer("buffer1", torch.ones(64))
# Non-persistent buffers are not included in the state dict. We need to make sure that this case is handled
# correctly by the partial loading code.
self.register_buffer("buffer2", torch.ones(64), persistent=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = self.linear2(x)
x = x + self.buffer1
x = x + self.buffer2
return x
parameterize_mps_and_cuda = pytest.mark.parametrize(
("device"),
[
pytest.param(
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
),
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
],
)


@@ -0,0 +1,144 @@
import pytest
import torch
if not torch.cuda.is_available():
pytest.skip("CUDA is not available", allow_module_level=True)
else:
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
CustomInvokeLinear8bitLt,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
CustomInvokeLinearNF4,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
@pytest.fixture
def linear_8bit_lt_layer():
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
torch.manual_seed(1)
orig_layer = torch.nn.Linear(32, 64)
orig_layer_state_dict = orig_layer.state_dict()
# Prepare a quantized InvokeLinear8bitLt layer.
quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
quantized_layer.load_state_dict(orig_layer_state_dict)
quantized_layer.to("cuda")
# Assert that the InvokeLinear8bitLt layer is quantized.
assert quantized_layer.weight.CB is not None
assert quantized_layer.weight.SCB is not None
assert quantized_layer.weight.CB.dtype == torch.int8
return quantized_layer
def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
"""Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
# Run inference on the original layer.
x = torch.randn(1, 32).to("cuda")
y_quantized = linear_8bit_lt_layer(x)
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
y_custom = linear_8bit_lt_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
"""Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
# Run inference on the original layer.
x = torch.randn(1, 32).to("cuda")
y_quantized = linear_8bit_lt_layer(x)
# Copy the state dict to the CPU and reload it.
state_dict = linear_8bit_lt_layer.state_dict()
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
linear_8bit_lt_layer.load_state_dict(state_dict)
# Inference of the original layer should fail.
with pytest.raises(RuntimeError):
linear_8bit_lt_layer(x)
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
y_custom = linear_8bit_lt_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
@pytest.fixture
def linear_nf4_layer():
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
torch.manual_seed(1)
orig_layer = torch.nn.Linear(64, 16)
orig_layer_state_dict = orig_layer.state_dict()
# Prepare a quantized InvokeLinearNF4 layer.
quantized_layer = InvokeLinearNF4(input_features=64, output_features=16)
quantized_layer.load_state_dict(orig_layer_state_dict)
quantized_layer.to("cuda")
# Assert that the InvokeLinearNF4 layer is quantized.
assert quantized_layer.weight.bnb_quantized
return quantized_layer
def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
"""Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
# Run inference on the original layer.
x = torch.randn(1, 64).to("cuda")
y_quantized = linear_nf4_layer(x)
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
linear_nf4_layer.__class__ = CustomInvokeLinearNF4
y_custom = linear_nf4_layer(x)
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the
# input dimension, and this has caused issues in the past.
@pytest.mark.parametrize("input_dim_0", [1, 2])
def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int):
"""Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
# Run inference on the original layer.
x = torch.randn(input_dim_0, 64).to(device="cuda")
y_quantized = linear_nf4_layer(x)
# Copy the state dict to the CPU and reload it.
state_dict = linear_nf4_layer.state_dict()
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
linear_nf4_layer.load_state_dict(state_dict)
# Inference of the original layer should fail.
with pytest.raises(RuntimeError):
linear_nf4_layer(x)
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
linear_nf4_layer.__class__ = CustomInvokeLinearNF4
y_custom = linear_nf4_layer(x)
# Assert that the state dict (and the tensors that it references) are still on the CPU.
assert all(v.device == torch.device("cpu") for v in state_dict.values())
# Assert that the weight, bias, and quant_state are all on the CPU.
assert linear_nf4_layer.weight.device == torch.device("cpu")
assert linear_nf4_layer.bias.device == torch.device("cpu")
assert linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu")
assert linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu")
# Assert that the quantized and custom layers produce the same output.
assert torch.allclose(y_quantized, y_custom, atol=1e-5)


@@ -0,0 +1,132 @@
import os
import gguf
import pytest
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
apply_custom_layers_to_model,
remove_custom_layers_from_model,
)
from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
try:
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8
except ImportError:
# This is expected to fail on MacOS
pass
cuda_and_mps = pytest.mark.parametrize(
"device",
[
pytest.param(
torch.device("cuda"), marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
),
pytest.param(
torch.device("mps"),
marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS device"),
),
],
)
class ModelWithLinearLayer(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(32, 64)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
@pytest.fixture(params=["none", "gguf"])
def model(request: pytest.FixtureRequest) -> torch.nn.Module:
if request.param == "none":
return ModelWithLinearLayer()
elif request.param == "gguf":
# Initialize ModelWithLinearLayer and replace the linear layer weight with a GGML quantized weight.
model = ModelWithLinearLayer()
ggml_quantized_weight = quantize_tensor(model.linear.weight, gguf.GGMLQuantizationType.Q8_0)
model.linear.weight = torch.nn.Parameter(ggml_quantized_weight)
return model
else:
raise ValueError(f"Invalid quantization type: {request.param}")
@cuda_and_mps
@torch.no_grad()
def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.nn.Module):
# Skip this test with MPS on GitHub Actions. It fails but I haven't taken the time to figure out why. It passes
# locally on MacOS.
if os.environ.get("GITHUB_ACTIONS") == "true" and device.type == "mps":
pytest.skip("This test is flaky on GitHub Actions")
# Model parameters should start off on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
torch.manual_seed(0)
# Run inference on the CPU.
x = torch.randn(1, 32, device="cpu")
expected = model(x)
assert expected.device.type == "cpu"
# Apply the custom layers to the model.
apply_custom_layers_to_model(model)
# Run the model on the device.
autocast_result = model(x.to(device))
# The model output should be on the device.
assert autocast_result.device.type == device.type
# The model parameters should still be on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
# Remove the custom layers from the model.
remove_custom_layers_from_model(model)
# After removing the custom layers, the model should no longer be able to run inference on the device.
with pytest.raises(RuntimeError):
_ = model(x.to(device))
# Run inference again on the CPU.
after_result = model(x)
assert after_result.device.type == "cpu"
# The results from all inference runs should be the same.
assert torch.allclose(autocast_result.to("cpu"), expected, atol=1e-5)
assert torch.allclose(after_result, expected, atol=1e-5)
@torch.no_grad()
def test_torch_module_autocast_bnb_llm_int8_linear_layer():
if not torch.cuda.is_available():
pytest.skip("requires CUDA device")
torch.manual_seed(0)
model = ModelWithLinearLayer()
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
# The act of moving the model to the CUDA device will trigger quantization.
model.to("cuda")
# Confirm that the layer is quantized.
assert isinstance(model.linear, InvokeLinear8bitLt)
assert model.linear.weight.CB is not None
assert model.linear.weight.SCB is not None
# Run inference on the GPU.
x = torch.randn(1, 32)
expected = model(x.to("cuda"))
assert expected.device.type == "cuda"
# Move the model back to the CPU and add the custom layers to the model.
model.to("cpu")
apply_custom_layers_to_model(model)
# Run inference with weights being streamed to the GPU.
autocast_result = model(x.to("cuda"))
assert autocast_result.device.type == "cuda"
# The results from all inference runs should be the same.
assert torch.allclose(autocast_result, expected, atol=1e-5)


@@ -0,0 +1,85 @@
import pytest
import torch
try:
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
except ImportError:
pass
def test_invoke_linear_8bit_lt_quantization():
"""Test quantization with InvokeLinear8bitLt."""
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
# Set the seed for reproducibility since we are using a pretty tight atol.
torch.manual_seed(3)
orig_layer = torch.nn.Linear(32, 64)
orig_layer_state_dict = orig_layer.state_dict()
# Initialize a InvokeLinear8bitLt layer (it is not quantized yet).
quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
# Load the non-quantized layer's state dict into the quantized layer.
quantized_layer.load_state_dict(orig_layer_state_dict)
# Move the InvokeLinear8bitLt layer to the GPU. This triggers quantization.
quantized_layer.to("cuda")
# Assert that the InvokeLinear8bitLt layer is quantized.
assert quantized_layer.weight.CB is not None
assert quantized_layer.weight.SCB is not None
assert quantized_layer.weight.CB.dtype == torch.int8
# Run inference on both the original and quantized layers.
x = torch.randn(1, 32)
y = orig_layer(x)
y_quantized = quantized_layer(x.to("cuda"))
assert y.shape == y_quantized.shape
# All within ~20% of each other.
assert torch.allclose(y, y_quantized.to("cpu"), atol=0.05)
def test_invoke_linear_8bit_lt_state_dict_roundtrip():
"""Test that we can roundtrip the state dict of a quantized InvokeLinear8bitLt layer."""
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")
# Set the seed for reproducibility since we are using a pretty tight atol.
torch.manual_seed(3)
orig_layer = torch.nn.Linear(32, 64)
orig_layer_state_dict = orig_layer.state_dict()
# Run inference on the original layer.
x = torch.randn(1, 32)
y = orig_layer(x)
# Prepare a quantized InvokeLinear8bitLt layer.
quantized_layer_1 = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
quantized_layer_1.load_state_dict(orig_layer_state_dict)
quantized_layer_1.to("cuda")
# Assert that the InvokeLinear8bitLt layer is quantized.
assert quantized_layer_1.weight.CB is not None
assert quantized_layer_1.weight.SCB is not None
assert quantized_layer_1.weight.CB.dtype == torch.int8
# Run inference on the quantized layer.
y_quantized_1 = quantized_layer_1(x.to("cuda"))
# Save the state dict of the quantized layer.
quantized_layer_1_state_dict = quantized_layer_1.state_dict()
# Load the state dict of the quantized layer into a new quantized layer.
quantized_layer_2 = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
quantized_layer_2.load_state_dict(quantized_layer_1_state_dict)
quantized_layer_2.to("cuda")
# Run inference on the new quantized layer.
y_quantized_2 = quantized_layer_2(x.to("cuda"))
# Assert that the inference results are the same.
assert torch.allclose(y, y_quantized_1.to("cpu"), atol=0.05)
assert torch.allclose(y_quantized_1, y_quantized_2, atol=1e-5)