mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
Partial Loading PR2: Add utils to support partial loading of models from CPU to GPU (#7494)
## Summary This PR adds utilities to support partial loading of models from CPU to GPU. The new utilities are not yet being used by the ModelCache, so there should be no functional behavior changes in this PR. Detailed changes: - Add autocast modules that are designed to wrap common `torch.nn.Module`s and enable them to run with automatic device casting. E.g. a linear layer on the CPU can be executed with an input tensor on the GPU by streaming the weights to the GPU at runtime. - Add unit tests for the aforementioned autocast modules to verify that they work for all supported quantization formats (GGUF, BnB NF4, BnB LLM.int8()). - Add `CachedModelWithPartialLoad` and `CachedModelOnlyFullLoad` classes to manage partial loading at the model level. ## Alternative Implementations Several options were explored for supporting inference on partially-loaded models. The pros/cons of the explored options are summarized here for reference. In the end, wrapper modules were selected as the best overall solution for our use case. Option 1: Re-implement the .forward() methods of modules to add support for device conversions - This is the option implemented in this PR. - This approach is the most manual of the three, but as a result offers the broadest compatibility with unusual model types. It is manual in that we have to explicitly add support for all module types that we wish to support. Fortunately, the list of foundational module types is relatively small (e.g. the current set of implemented layers covers all but 0.04 MB of the full FLUX model.). Option 2: Implement a custom Tensor type that casts tensors to a `target_device` each time the tensor is used - This approach has the nice property that it is injected at the tensor level, and the model does not need to be modified in any way. - One challenge with this approach is handling interactions with other custom tensor types (e.g. GGMLTensor). This problem is solvable, but definitely introduces a layer of complexity. (There are likely to also be some similar issues with interactions with the BnB quantization, but I didn't get as far as testing BnB.) Option 3: Override the `__torch_function__` dispatch calls globally and cast all params to the execution device. - This approach is nice and simple: just apply a global context manager and all operations will happen on the compute device regardless of the device of the participating tensors. - Challenges: - Overriding the `__torch_function__` dispatch calls introduces some overhead even if the tensors are already on the correct device. - It is difficult to manage the autocasting context manager. E.g. it is tempting to apply it to the model's `.forward(...)` method, but we use some models with non-standard entrypoints. And we don't want to end up with nested autocasting context managers. - BnB applies quantization side effects when a param is moved to the GPU - this interacts in unexpected ways with a global context manager. ## QA Instructions Most of the changes in this PR should not impact active code, and thus should not cause any changes to behavior. The main risks come from bumping the bitsandbytes dependency and some minor modifications to the bitsandbytes quantization code. - [x] Regression test bitsandbytes NF4 quantization - [x] Regression test bitsandbytes LLM.int8() quantization - [x] Regression test on MacOS (to ensure that there are no lingering bitsandbytes import errors) I also tested the new utilities for inference on full models in another branch to validate that there were not major issues. This functionality will be tested more thoroughly in a future PR. ## Merge Plan - [x] #7492 should be merged first so that the target branch can be updated to main. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [x] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
commit
6bf5b747ce
@ -0,0 +1,93 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class CachedModelOnlyFullLoad:
|
||||
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
|
||||
"""Initialize a CachedModelOnlyFullLoad.
|
||||
Args:
|
||||
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
|
||||
compute_device (torch.device): The compute device to move the model to.
|
||||
total_bytes (int): The total size (in bytes) of all the weights in the model.
|
||||
"""
|
||||
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
self._offload_device = torch.device("cpu")
|
||||
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] | None = None
|
||||
if isinstance(model, torch.nn.Module):
|
||||
self._cpu_state_dict = model.state_dict()
|
||||
|
||||
self._total_bytes = total_bytes
|
||||
self._is_in_vram = False
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better.
|
||||
return self._cpu_state_dict
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._is_in_vram:
|
||||
return self._total_bytes
|
||||
else:
|
||||
return 0
|
||||
|
||||
def is_in_vram(self) -> bool:
|
||||
"""Return true if the model is currently in VRAM."""
|
||||
return self._is_in_vram
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM (if supported by the model).
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
if self._is_in_vram:
|
||||
# Already in VRAM.
|
||||
return 0
|
||||
|
||||
if not hasattr(self._model, "to"):
|
||||
# Model doesn't support moving to a device.
|
||||
return 0
|
||||
|
||||
if self._cpu_state_dict is not None:
|
||||
new_state_dict: dict[str, torch.Tensor] = {}
|
||||
for k, v in self._cpu_state_dict.items():
|
||||
new_state_dict[k] = v.to(self._compute_device, copy=True)
|
||||
self._model.load_state_dict(new_state_dict, assign=True)
|
||||
self._model.to(self._compute_device)
|
||||
|
||||
self._is_in_vram = True
|
||||
return self._total_bytes
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM.
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
if not self._is_in_vram:
|
||||
# Already in RAM.
|
||||
return 0
|
||||
|
||||
if self._cpu_state_dict is not None:
|
||||
self._model.load_state_dict(self._cpu_state_dict, assign=True)
|
||||
self._model.to(self._offload_device)
|
||||
|
||||
self._is_in_vram = False
|
||||
return self._total_bytes
|
@ -0,0 +1,201 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
AUTOCAST_MODULE_TYPE_MAPPING,
|
||||
apply_custom_layers_to_model,
|
||||
remove_custom_layers_from_model,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
def set_nested_attr(obj: object, attr: str, value: object):
|
||||
"""A helper function that extends setattr() to support nested attributes.
|
||||
|
||||
Example:
|
||||
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
|
||||
"""
|
||||
attrs = attr.split(".")
|
||||
for attr in attrs[:-1]:
|
||||
obj = getattr(obj, attr)
|
||||
setattr(obj, attrs[-1], value)
|
||||
|
||||
|
||||
class CachedModelWithPartialLoad:
|
||||
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
|
||||
|
||||
# TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting).
|
||||
# Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
|
||||
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
|
||||
self._cur_vram_bytes: int | None = None
|
||||
|
||||
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
|
||||
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
|
||||
|
||||
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
|
||||
"""Find all modules that support autocasting."""
|
||||
return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}
|
||||
|
||||
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
|
||||
keys_in_modules_that_do_not_support_autocast = set()
|
||||
for key in self._cpu_state_dict.keys():
|
||||
for module_name in self._modules_that_support_autocast.keys():
|
||||
if key.startswith(module_name):
|
||||
break
|
||||
else:
|
||||
keys_in_modules_that_do_not_support_autocast.add(key)
|
||||
return keys_in_modules_that_do_not_support_autocast
|
||||
|
||||
def _move_non_persistent_buffers_to_device(self, device: torch.device):
|
||||
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
|
||||
so we need to move them manually.
|
||||
"""
|
||||
# HACK(ryand): Typically, non-persistent buffers are moved when calling module.to(device). We don't move entire
|
||||
# modules, because we manage the devices of individual tensors using the state dict. Since non-persistent
|
||||
# buffers are not included in the state dict, we need to handle them manually. The only way to do this is by
|
||||
# using private torch.nn.Module attributes.
|
||||
for module in self._model.modules():
|
||||
for name, buffer in module.named_buffers():
|
||||
if name in module._non_persistent_buffers_set:
|
||||
module._buffers[name] = buffer.to(device, copy=True)
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better.
|
||||
return self._cpu_state_dict
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._cur_vram_bytes is None:
|
||||
cur_state_dict = self._model.state_dict()
|
||||
self._cur_vram_bytes = sum(
|
||||
calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
|
||||
)
|
||||
return self._cur_vram_bytes
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM."""
|
||||
return self.partial_load_to_vram(self.total_bytes())
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM."""
|
||||
return self.partial_unload_from_vram(self.total_bytes())
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
|
||||
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
|
||||
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
# TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very
|
||||
# least, we should reset self._cur_vram_bytes to None.
|
||||
|
||||
vram_bytes_loaded = 0
|
||||
|
||||
cur_state_dict = self._model.state_dict()
|
||||
|
||||
# First, process the keys *must* be loaded into VRAM.
|
||||
for key in self._keys_in_modules_that_do_not_support_autocast:
|
||||
param = cur_state_dict[key]
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
param_size = calc_tensor_size(param)
|
||||
cur_state_dict[key] = param.to(self._compute_device, copy=True)
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if vram_bytes_loaded > vram_bytes_to_load:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
logger.warning(
|
||||
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
|
||||
"requested. This is the minimum set of weights in VRAM required to run the model."
|
||||
)
|
||||
|
||||
# Next, process the keys that can optionally be loaded into VRAM.
|
||||
fully_loaded = True
|
||||
for key, param in cur_state_dict.items():
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
param_size = calc_tensor_size(param)
|
||||
if vram_bytes_loaded + param_size > vram_bytes_to_load:
|
||||
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
|
||||
# worth continuing to search for a smaller parameter that would fit?
|
||||
fully_loaded = False
|
||||
continue
|
||||
|
||||
cur_state_dict[key] = param.to(self._compute_device, copy=True)
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if vram_bytes_loaded > 0:
|
||||
# We load the entire state dict, not just the parameters that changed, in case there are modules that
|
||||
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
|
||||
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
|
||||
self._model.load_state_dict(cur_state_dict, assign=True)
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes += vram_bytes_loaded
|
||||
|
||||
if fully_loaded:
|
||||
remove_custom_layers_from_model(self._model)
|
||||
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
|
||||
else:
|
||||
apply_custom_layers_to_model(self._model)
|
||||
|
||||
# Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
|
||||
# the vram_bytes_loaded tracking.
|
||||
self._move_non_persistent_buffers_to_device(self._compute_device)
|
||||
|
||||
return vram_bytes_loaded
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
|
||||
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
vram_bytes_freed = 0
|
||||
|
||||
offload_device = "cpu"
|
||||
cur_state_dict = self._model.state_dict()
|
||||
for key, param in cur_state_dict.items():
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
|
||||
if param.device.type == offload_device:
|
||||
continue
|
||||
|
||||
cur_state_dict[key] = self._cpu_state_dict[key]
|
||||
vram_bytes_freed += calc_tensor_size(param)
|
||||
|
||||
if vram_bytes_freed > 0:
|
||||
self._model.load_state_dict(cur_state_dict, assign=True)
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes -= vram_bytes_freed
|
||||
|
||||
# We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
|
||||
# layers.
|
||||
apply_custom_layers_to_model(self._model)
|
||||
return vram_bytes_freed
|
@ -0,0 +1,50 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
|
||||
|
||||
# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
|
||||
# Each class sub-classes the original module type that is is replacing, so the following properties are preserved:
|
||||
# - isinstance(m, torch.nn.OrginalModule) should still work.
|
||||
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.
|
||||
|
||||
|
||||
class CustomLinear(torch.nn.Linear):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
bias = cast_to_device(self.bias, input.device)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
|
||||
class CustomConv1d(torch.nn.Conv1d):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
bias = cast_to_device(self.bias, input.device)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
|
||||
class CustomConv2d(torch.nn.Conv2d):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
bias = cast_to_device(self.bias, input.device)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
|
||||
class CustomGroupNorm(torch.nn.GroupNorm):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
bias = cast_to_device(self.bias, input.device)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
|
||||
class CustomEmbedding(torch.nn.Embedding):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
return torch.nn.functional.embedding(
|
||||
input,
|
||||
weight,
|
||||
self.padding_idx,
|
||||
self.max_norm,
|
||||
self.norm_type,
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
)
|
@ -0,0 +1,15 @@
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
|
||||
|
||||
|
||||
def cast_to_device(t: T, to_device: torch.device) -> T:
|
||||
"""Helper function to cast an optional tensor to a target device."""
|
||||
if t is None:
|
||||
return t
|
||||
|
||||
if t.device.type != to_device.type:
|
||||
return t.to(to_device)
|
||||
return t
|
@ -0,0 +1,27 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
|
||||
|
||||
|
||||
class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
matmul_state = bnb.MatmulLtState()
|
||||
matmul_state.threshold = self.state.threshold
|
||||
matmul_state.has_fp16_weights = self.state.has_fp16_weights
|
||||
matmul_state.use_pool = self.state.use_pool
|
||||
matmul_state.is_training = self.training
|
||||
# The underlying InvokeInt8Params weight must already be quantized.
|
||||
assert self.weight.CB is not None
|
||||
matmul_state.CB = cast_to_device(self.weight.CB, x.device)
|
||||
matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)
|
||||
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually.
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
# NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
|
||||
# it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
|
||||
# on the wrong device.
|
||||
return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
|
@ -0,0 +1,45 @@
|
||||
import copy
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
|
||||
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
|
||||
|
||||
|
||||
class CustomInvokeLinearNF4(InvokeLinearNF4):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)
|
||||
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
if not self.compute_type_is_set:
|
||||
self.set_compute_type(x)
|
||||
self.compute_type_is_set = True
|
||||
|
||||
inp_dtype = x.dtype
|
||||
if self.compute_dtype is not None:
|
||||
x = x.to(self.compute_dtype)
|
||||
|
||||
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||
|
||||
# HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it
|
||||
# does not follow the tensor semantics of returning a new copy when converting to a different device). This
|
||||
# means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To
|
||||
# avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing
|
||||
# this properly would require more invasive changes to the bitsandbytes library.
|
||||
|
||||
# Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting
|
||||
# to a new device.
|
||||
old_quant_state = copy.copy(self.weight.quant_state)
|
||||
weight = cast_to_device(self.weight, x.device)
|
||||
self.weight.quant_state = old_quant_state
|
||||
|
||||
# For some reason, the quant_state.to(...) implementation fails to cast the quant_state.code field. We do this
|
||||
# manually here.
|
||||
weight.quant_state.code = cast_to_device(weight.quant_state.code, x.device)
|
||||
|
||||
bias = cast_to_device(self.bias, x.device)
|
||||
return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
|
@ -0,0 +1,56 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
|
||||
CustomConv1d,
|
||||
CustomConv2d,
|
||||
CustomEmbedding,
|
||||
CustomGroupNorm,
|
||||
CustomLinear,
|
||||
)
|
||||
|
||||
AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
|
||||
torch.nn.Linear: CustomLinear,
|
||||
torch.nn.Conv1d: CustomConv1d,
|
||||
torch.nn.Conv2d: CustomConv2d,
|
||||
torch.nn.GroupNorm: CustomGroupNorm,
|
||||
torch.nn.Embedding: CustomEmbedding,
|
||||
}
|
||||
|
||||
try:
|
||||
# These dependencies are not expected to be present on MacOS.
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
|
||||
CustomInvokeLinear8bitLt,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
|
||||
CustomInvokeLinearNF4,
|
||||
)
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
|
||||
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
|
||||
|
||||
AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinear8bitLt] = CustomInvokeLinear8bitLt
|
||||
AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinearNF4] = CustomInvokeLinearNF4
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def apply_custom_layers_to_model(model: torch.nn.Module):
|
||||
def apply_custom_layers(module: torch.nn.Module):
|
||||
override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
|
||||
if override_type is not None:
|
||||
module.__class__ = override_type
|
||||
|
||||
# model.apply(...) calls apply_custom_layers(...) on each module in the model.
|
||||
model.apply(apply_custom_layers)
|
||||
|
||||
|
||||
def remove_custom_layers_from_model(model: torch.nn.Module):
|
||||
# Invert AUTOCAST_MODULE_TYPE_MAPPING.
|
||||
original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
|
||||
|
||||
def remove_custom_layers(module: torch.nn.Module):
|
||||
override_type = original_module_type_mapping.get(type(module), None)
|
||||
if override_type is not None:
|
||||
module.__class__ = override_type
|
||||
|
||||
# model.apply(...) calls remove_custom_layers(...) on each module in the model.
|
||||
model.apply(remove_custom_layers)
|
@ -25,12 +25,9 @@ class InvokeInt8Params(bnb.nn.Int8Params):
|
||||
self.CB = self.data
|
||||
self.SCB = self.SCB.cuda()
|
||||
else:
|
||||
# we store the 8-bit rows-major weight
|
||||
# we convert this weight to the turning/ampere weight during the first inference pass
|
||||
# We quantize the weight and store in 8bit row-major
|
||||
B = self.data.contiguous().half().cuda(device)
|
||||
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
|
||||
self.data = CB
|
||||
self.CB = CB
|
||||
self.SCB = SCB
|
||||
@ -55,9 +52,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
|
||||
# Currently, we only support weight_format=0.
|
||||
weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
assert weight_format == 0
|
||||
if weight_format is not None:
|
||||
# Currently, we only support weight_format=0.
|
||||
assert weight_format == 0
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
@ -99,6 +97,27 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
new_state.use_pool = self.state.use_pool
|
||||
self.state = new_state
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# The state management in the base bnb.nn.Linear8bitLt is very convoluted. We override the forward method to
|
||||
# try to simplify the state management a bit. We initialize a new MatmulLtState object for each forward pass.
|
||||
# By avoiding persistent state, it is easier to move the layer between devices without worrying about keeping
|
||||
# references to weights on the old device (e.g. self.state.CB).
|
||||
matmul_state = bnb.MatmulLtState()
|
||||
matmul_state.threshold = self.state.threshold
|
||||
matmul_state.has_fp16_weights = self.state.has_fp16_weights
|
||||
matmul_state.use_pool = self.state.use_pool
|
||||
matmul_state.is_training = self.training
|
||||
# The underlying InvokeInt8Params weight must already be quantized.
|
||||
assert self.weight.CB is not None
|
||||
matmul_state.CB = self.weight.CB
|
||||
matmul_state.SCB = self.weight.SCB
|
||||
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually.
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
return bnb.matmul(x, self.weight, bias=self.bias, state=matmul_state)
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
|
@ -34,7 +34,7 @@ classifiers = [
|
||||
dependencies = [
|
||||
# Core generation dependencies, pinned for reproducible builds.
|
||||
"accelerate==1.0.1",
|
||||
"bitsandbytes==0.43.3; sys_platform!='darwin'",
|
||||
"bitsandbytes==0.45.0; sys_platform!='darwin'",
|
||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==2.0.2",
|
||||
"controlnet-aux==0.0.7",
|
||||
|
@ -0,0 +1,122 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
|
||||
|
||||
|
||||
class NonTorchModel:
|
||||
"""A model that does not sub-class torch.nn.Module."""
|
||||
|
||||
def __init__(self):
|
||||
self.linear = torch.nn.Linear(10, 32)
|
||||
|
||||
def run_inference(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert cached_model.total_bytes() == 100
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_is_in_vram(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert not cached_model.is_in_vram()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.is_in_vram()
|
||||
assert cached_model.cur_vram_bytes() == 100
|
||||
|
||||
cached_model.full_unload_from_vram()
|
||||
assert not cached_model.is_in_vram()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_unload(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert cached_model.full_load_to_vram() == 100
|
||||
assert cached_model.is_in_vram()
|
||||
assert all(p.device.type == device for p in cached_model.model.parameters())
|
||||
|
||||
assert cached_model.full_unload_from_vram() == 100
|
||||
assert not cached_model.is_in_vram()
|
||||
assert all(p.device.type == "cpu" for p in cached_model.model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_get_cpu_state_dict(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
# The CPU state dict can be accessed and has the expected properties.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.is_in_vram()
|
||||
|
||||
# The CPU state dict is still available, and still on the CPU.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_inference(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
# Run inference on the CPU.
|
||||
x = torch.randn(1, 10)
|
||||
output1 = model(x)
|
||||
assert output1.device.type == "cpu"
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.is_in_vram()
|
||||
|
||||
# Run inference on the GPU.
|
||||
output2 = model(x.to(device))
|
||||
assert output2.device.type == device
|
||||
|
||||
# The outputs should be the same for both runs.
|
||||
assert torch.allclose(output1, output2.to("cpu"))
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_non_torch_model(device: str):
|
||||
model = NonTorchModel()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
# The model does not have a CPU state dict.
|
||||
assert cached_model.get_cpu_state_dict() is None
|
||||
|
||||
# Attempting to load the model into VRAM should have no effect.
|
||||
cached_model.full_load_to_vram()
|
||||
assert not cached_model.is_in_vram()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Attempting to unload the model from VRAM should have no effect.
|
||||
cached_model.full_unload_from_vram()
|
||||
assert not cached_model.is_in_vram()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Running inference on the CPU should work.
|
||||
output1 = model.run_inference(torch.randn(1, 10))
|
||||
assert output1.device.type == "cpu"
|
@ -0,0 +1,274 @@
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
linear1_numel = 10 * 32 + 32
|
||||
linear2_numel = 32 * 64 + 64
|
||||
buffer1_numel = 64
|
||||
# Note that the non-persistent buffer (buffer2) is not included in .total_bytes() calculation.
|
||||
assert cached_model.total_bytes() == (linear1_numel + linear2_numel + buffer1_numel) * 4
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_cur_vram_bytes(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() > 0
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
assert all(p.device.type == device for p in model.buffers())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
|
||||
# Check that the model is partially loaded into VRAM.
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes == sum(
|
||||
calc_tensor_size(p)
|
||||
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
|
||||
if p.device.type == device and n != "buffer2"
|
||||
)
|
||||
|
||||
# Check that the model's modules have been patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_unload(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() == model_total_bytes
|
||||
|
||||
# Partially unload the model from VRAM.
|
||||
bytes_to_free = int(model_total_bytes * 0.4)
|
||||
freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
|
||||
|
||||
# Check that the model is partially unloaded from VRAM.
|
||||
assert freed_bytes >= bytes_to_free
|
||||
assert freed_bytes < model_total_bytes
|
||||
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
|
||||
assert freed_bytes == sum(
|
||||
calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
|
||||
)
|
||||
|
||||
# Check that the model's modules are still patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_unload(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
loaded_bytes = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes == model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
assert type(model.linear1) is torch.nn.Linear
|
||||
assert type(model.linear2) is torch.nn.Linear
|
||||
|
||||
# Full unload the model from VRAM.
|
||||
unloaded_bytes = cached_model.full_unload_from_vram()
|
||||
|
||||
# Check that the model is fully unloaded from VRAM.
|
||||
assert unloaded_bytes > 0
|
||||
assert unloaded_bytes == model_total_bytes
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
# Note that the non-persistent buffer (buffer2) is not required to be unloaded from VRAM.
|
||||
assert all(
|
||||
p.device.type == "cpu"
|
||||
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
|
||||
if n != "buffer2"
|
||||
)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
|
||||
# Full load the rest of the model into VRAM.
|
||||
loaded_bytes_2 = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes_2 > 0
|
||||
assert loaded_bytes_2 < model_total_bytes
|
||||
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
|
||||
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
assert type(model.linear1) is torch.nn.Linear
|
||||
assert type(model.linear2) is torch.nn.Linear
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_unload_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
|
||||
# Full unload the model from VRAM.
|
||||
unloaded_bytes = cached_model.full_unload_from_vram()
|
||||
assert unloaded_bytes > 0
|
||||
assert unloaded_bytes == loaded_bytes
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
# Note that the non-persistent buffer (buffer2) is not required to be unloaded from VRAM.
|
||||
assert all(
|
||||
p.device.type == "cpu"
|
||||
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
|
||||
if n != "buffer2"
|
||||
)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_get_cpu_state_dict(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# The CPU state dict can be accessed and has the expected properties.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
|
||||
# The CPU state dict is still available, and still on the CPU.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_inference(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Run inference on the CPU.
|
||||
x = torch.randn(1, 10)
|
||||
output1 = model(x)
|
||||
assert output1.device.type == "cpu"
|
||||
|
||||
# Full load the model into VRAM.
|
||||
loaded_bytes = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes == model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
|
||||
# Run inference on the GPU.
|
||||
output2 = model(x.to(device))
|
||||
assert output2.device.type == device
|
||||
|
||||
# The outputs should be the same for both runs.
|
||||
assert torch.allclose(output1, output2.to("cpu"))
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load_and_inference(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Run inference on the CPU.
|
||||
x = torch.randn(1, 10)
|
||||
output1 = model(x)
|
||||
assert output1.device.type == "cpu"
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
|
||||
# Check that the model is partially loaded into VRAM.
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes == sum(
|
||||
calc_tensor_size(p)
|
||||
for n, p in itertools.chain(model.named_parameters(), model.named_buffers())
|
||||
if p.device.type == device and n != "buffer2"
|
||||
)
|
||||
# Check that the model's modules have been patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
|
||||
# Run inference on the GPU.
|
||||
output2 = model(x.to(device))
|
||||
assert output2.device.type == device
|
||||
|
||||
# The output should be the same as the output from the CPU.
|
||||
assert torch.allclose(output1, output2.to("cpu"))
|
@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 32)
|
||||
self.linear2 = torch.nn.Linear(32, 64)
|
||||
self.register_buffer("buffer1", torch.ones(64))
|
||||
# Non-persistent buffers are not included in the state dict. We need to make sure that this case is handled
|
||||
# correctly by the partial loading code.
|
||||
self.register_buffer("buffer2", torch.ones(64), persistent=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
x = x + self.buffer1
|
||||
x = x + self.buffer2
|
||||
return x
|
||||
|
||||
|
||||
parameterize_mps_and_cuda = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
@ -0,0 +1,144 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available", allow_module_level=True)
|
||||
else:
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
|
||||
CustomInvokeLinear8bitLt,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
|
||||
CustomInvokeLinearNF4,
|
||||
)
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
|
||||
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def linear_8bit_lt_layer():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available")
|
||||
|
||||
torch.manual_seed(1)
|
||||
|
||||
orig_layer = torch.nn.Linear(32, 64)
|
||||
orig_layer_state_dict = orig_layer.state_dict()
|
||||
|
||||
# Prepare a quantized InvokeLinear8bitLt layer.
|
||||
quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
|
||||
quantized_layer.load_state_dict(orig_layer_state_dict)
|
||||
quantized_layer.to("cuda")
|
||||
|
||||
# Assert that the InvokeLinear8bitLt layer is quantized.
|
||||
assert quantized_layer.weight.CB is not None
|
||||
assert quantized_layer.weight.SCB is not None
|
||||
assert quantized_layer.weight.CB.dtype == torch.int8
|
||||
|
||||
return quantized_layer
|
||||
|
||||
|
||||
def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
|
||||
"""Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(1, 32).to("cuda")
|
||||
y_quantized = linear_8bit_lt_layer(x)
|
||||
|
||||
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
|
||||
linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
|
||||
y_custom = linear_8bit_lt_layer(x)
|
||||
|
||||
# Assert that the quantized and custom layers produce the same output.
|
||||
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
|
||||
|
||||
|
||||
def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
|
||||
"""Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(1, 32).to("cuda")
|
||||
y_quantized = linear_8bit_lt_layer(x)
|
||||
|
||||
# Copy the state dict to the CPU and reload it.
|
||||
state_dict = linear_8bit_lt_layer.state_dict()
|
||||
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
|
||||
linear_8bit_lt_layer.load_state_dict(state_dict)
|
||||
|
||||
# Inference of the original layer should fail.
|
||||
with pytest.raises(RuntimeError):
|
||||
linear_8bit_lt_layer(x)
|
||||
|
||||
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
|
||||
linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt
|
||||
y_custom = linear_8bit_lt_layer(x)
|
||||
|
||||
# Assert that the quantized and custom layers produce the same output.
|
||||
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def linear_nf4_layer():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available")
|
||||
|
||||
torch.manual_seed(1)
|
||||
|
||||
orig_layer = torch.nn.Linear(64, 16)
|
||||
orig_layer_state_dict = orig_layer.state_dict()
|
||||
|
||||
# Prepare a quantized InvokeLinearNF4 layer.
|
||||
quantized_layer = InvokeLinearNF4(input_features=64, output_features=16)
|
||||
quantized_layer.load_state_dict(orig_layer_state_dict)
|
||||
quantized_layer.to("cuda")
|
||||
|
||||
# Assert that the InvokeLinearNF4 layer is quantized.
|
||||
assert quantized_layer.weight.bnb_quantized
|
||||
|
||||
return quantized_layer
|
||||
|
||||
|
||||
def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
|
||||
"""Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(1, 64).to("cuda")
|
||||
y_quantized = linear_nf4_layer(x)
|
||||
|
||||
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
|
||||
linear_nf4_layer.__class__ = CustomInvokeLinearNF4
|
||||
y_custom = linear_nf4_layer(x)
|
||||
|
||||
# Assert that the quantized and custom layers produce the same output.
|
||||
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
|
||||
|
||||
|
||||
# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the
|
||||
# input dimension, and this has caused issues in the past.
|
||||
@pytest.mark.parametrize("input_dim_0", [1, 2])
|
||||
def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int):
|
||||
"""Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(input_dim_0, 64).to(device="cuda")
|
||||
y_quantized = linear_nf4_layer(x)
|
||||
|
||||
# Copy the state dict to the CPU and reload it.
|
||||
state_dict = linear_nf4_layer.state_dict()
|
||||
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
|
||||
linear_nf4_layer.load_state_dict(state_dict)
|
||||
|
||||
# Inference of the original layer should fail.
|
||||
with pytest.raises(RuntimeError):
|
||||
linear_nf4_layer(x)
|
||||
|
||||
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
|
||||
linear_nf4_layer.__class__ = CustomInvokeLinearNF4
|
||||
y_custom = linear_nf4_layer(x)
|
||||
|
||||
# Assert that the state dict (and the tensors that it references) are still on the CPU.
|
||||
assert all(v.device == torch.device("cpu") for v in state_dict.values())
|
||||
|
||||
# Assert that the weight, bias, and quant_state are all on the CPU.
|
||||
assert linear_nf4_layer.weight.device == torch.device("cpu")
|
||||
assert linear_nf4_layer.bias.device == torch.device("cpu")
|
||||
assert linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu")
|
||||
assert linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu")
|
||||
|
||||
# Assert that the quantized and custom layers produce the same output.
|
||||
assert torch.allclose(y_quantized, y_custom, atol=1e-5)
|
@ -0,0 +1,132 @@
|
||||
import os
|
||||
|
||||
import gguf
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
apply_custom_layers_to_model,
|
||||
remove_custom_layers_from_model,
|
||||
)
|
||||
from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
|
||||
|
||||
try:
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8
|
||||
except ImportError:
|
||||
# This is expected to fail on MacOS
|
||||
pass
|
||||
|
||||
cuda_and_mps = pytest.mark.parametrize(
|
||||
"device",
|
||||
[
|
||||
pytest.param(
|
||||
torch.device("cuda"), marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
|
||||
),
|
||||
pytest.param(
|
||||
torch.device("mps"),
|
||||
marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS device"),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ModelWithLinearLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(32, 64)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
@pytest.fixture(params=["none", "gguf"])
|
||||
def model(request: pytest.FixtureRequest) -> torch.nn.Module:
|
||||
if request.param == "none":
|
||||
return ModelWithLinearLayer()
|
||||
elif request.param == "gguf":
|
||||
# Initialize ModelWithLinearLayer and replace the linear layer weight with a GGML quantized weight.
|
||||
model = ModelWithLinearLayer()
|
||||
ggml_quantized_weight = quantize_tensor(model.linear.weight, gguf.GGMLQuantizationType.Q8_0)
|
||||
model.linear.weight = torch.nn.Parameter(ggml_quantized_weight)
|
||||
return model
|
||||
else:
|
||||
raise ValueError(f"Invalid quantization type: {request.param}")
|
||||
|
||||
|
||||
@cuda_and_mps
|
||||
@torch.no_grad()
|
||||
def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.nn.Module):
|
||||
# Skip this test with MPS on GitHub Actions. It fails but I haven't taken the tie to figure out why. It passes
|
||||
# locally on MacOS.
|
||||
if os.environ.get("GITHUB_ACTIONS") == "true" and device.type == "mps":
|
||||
pytest.skip("This test is flaky on GitHub Actions")
|
||||
|
||||
# Model parameters should start off on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Run inference on the CPU.
|
||||
x = torch.randn(1, 32, device="cpu")
|
||||
expected = model(x)
|
||||
assert expected.device.type == "cpu"
|
||||
|
||||
# Apply the custom layers to the model.
|
||||
apply_custom_layers_to_model(model)
|
||||
|
||||
# Run the model on the device.
|
||||
autocast_result = model(x.to(device))
|
||||
|
||||
# The model output should be on the device.
|
||||
assert autocast_result.device.type == device.type
|
||||
# The model parameters should still be on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
# Remove the custom layers from the model.
|
||||
remove_custom_layers_from_model(model)
|
||||
|
||||
# After removing the custom layers, the model should no longer be able to run inference on the device.
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = model(x.to(device))
|
||||
|
||||
# Run inference again on the CPU.
|
||||
after_result = model(x)
|
||||
|
||||
assert after_result.device.type == "cpu"
|
||||
|
||||
# The results from all inference runs should be the same.
|
||||
assert torch.allclose(autocast_result.to("cpu"), expected, atol=1e-5)
|
||||
assert torch.allclose(after_result, expected, atol=1e-5)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_torch_module_autocast_bnb_llm_int8_linear_layer():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("requires CUDA device")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
model = ModelWithLinearLayer()
|
||||
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
|
||||
# The act of moving the model to the CUDA device will trigger quantization.
|
||||
model.to("cuda")
|
||||
# Confirm that the layer is quantized.
|
||||
assert isinstance(model.linear, InvokeLinear8bitLt)
|
||||
assert model.linear.weight.CB is not None
|
||||
assert model.linear.weight.SCB is not None
|
||||
|
||||
# Run inference on the GPU.
|
||||
x = torch.randn(1, 32)
|
||||
expected = model(x.to("cuda"))
|
||||
assert expected.device.type == "cuda"
|
||||
|
||||
# Move the model back to the CPU and add the custom layers to the model.
|
||||
model.to("cpu")
|
||||
apply_custom_layers_to_model(model)
|
||||
|
||||
# Run inference with weights being streamed to the GPU.
|
||||
autocast_result = model(x.to("cuda"))
|
||||
assert autocast_result.device.type == "cuda"
|
||||
|
||||
# The results from all inference runs should be the same.
|
||||
assert torch.allclose(autocast_result, expected, atol=1e-5)
|
85
tests/backend/quantization/test_bnb_llm_int8.py
Normal file
85
tests/backend/quantization/test_bnb_llm_int8.py
Normal file
@ -0,0 +1,85 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
try:
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def test_invoke_linear_8bit_lt_quantization():
|
||||
"""Test quantization with InvokeLinear8bitLt."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available")
|
||||
|
||||
# Set the seed for reproducibility since we are using a pretty tight atol.
|
||||
torch.manual_seed(3)
|
||||
|
||||
orig_layer = torch.nn.Linear(32, 64)
|
||||
orig_layer_state_dict = orig_layer.state_dict()
|
||||
|
||||
# Initialize a InvokeLinear8bitLt layer (it is not quantized yet).
|
||||
quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
|
||||
|
||||
# Load the non-quantized layer's state dict into the quantized layer.
|
||||
quantized_layer.load_state_dict(orig_layer_state_dict)
|
||||
|
||||
# Move the InvokeLinear8bitLt layer to the GPU. This triggers quantization.
|
||||
quantized_layer.to("cuda")
|
||||
|
||||
# Assert that the InvokeLinear8bitLt layer is quantized.
|
||||
assert quantized_layer.weight.CB is not None
|
||||
assert quantized_layer.weight.SCB is not None
|
||||
assert quantized_layer.weight.CB.dtype == torch.int8
|
||||
|
||||
# Run inference on both the original and quantized layers.
|
||||
x = torch.randn(1, 32)
|
||||
y = orig_layer(x)
|
||||
y_quantized = quantized_layer(x.to("cuda"))
|
||||
assert y.shape == y_quantized.shape
|
||||
# All within ~20% of each other.
|
||||
assert torch.allclose(y, y_quantized.to("cpu"), atol=0.05)
|
||||
|
||||
|
||||
def test_invoke_linear_8bit_lt_state_dict_roundtrip():
|
||||
"""Test that we can roundtrip the state dict of a quantized InvokeLinear8bitLt layer."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available")
|
||||
|
||||
# Set the seed for reproducibility since we are using a pretty tight atol.
|
||||
torch.manual_seed(3)
|
||||
|
||||
orig_layer = torch.nn.Linear(32, 64)
|
||||
orig_layer_state_dict = orig_layer.state_dict()
|
||||
|
||||
# Run inference on the original layer.
|
||||
x = torch.randn(1, 32)
|
||||
y = orig_layer(x)
|
||||
|
||||
# Prepare a quantized InvokeLinear8bitLt layer.
|
||||
quantized_layer_1 = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
|
||||
quantized_layer_1.load_state_dict(orig_layer_state_dict)
|
||||
quantized_layer_1.to("cuda")
|
||||
|
||||
# Assert that the InvokeLinear8bitLt layer is quantized.
|
||||
assert quantized_layer_1.weight.CB is not None
|
||||
assert quantized_layer_1.weight.SCB is not None
|
||||
assert quantized_layer_1.weight.CB.dtype == torch.int8
|
||||
|
||||
# Run inference on the quantized layer.
|
||||
y_quantized_1 = quantized_layer_1(x.to("cuda"))
|
||||
|
||||
# Save the state dict of the quantized layer.
|
||||
quantized_layer_1_state_dict = quantized_layer_1.state_dict()
|
||||
|
||||
# Load the state dict of the quantized layer into a new quantized layer.
|
||||
quantized_layer_2 = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False)
|
||||
quantized_layer_2.load_state_dict(quantized_layer_1_state_dict)
|
||||
quantized_layer_2.to("cuda")
|
||||
|
||||
# Run inference on the new quantized layer.
|
||||
y_quantized_2 = quantized_layer_2(x.to("cuda"))
|
||||
|
||||
# Assert that the inference results are the same.
|
||||
assert torch.allclose(y, y_quantized_1.to("cpu"), atol=0.05)
|
||||
assert torch.allclose(y_quantized_1, y_quantized_2, atol=1e-5)
|
Loading…
Reference in New Issue
Block a user