Get alternative GGUF implementation working... barely.

This commit is contained in:
Ryan Dick 2024-09-30 22:36:25 +00:00 committed by Kent Keirsey
parent f347b26999
commit f06765dfba
4 changed files with 126 additions and 77 deletions

View File

@@ -37,7 +37,6 @@ from invokeai.backend.model_manager.util.model_util import (
convert_bundle_to_flux_transformer_checkpoint,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher
from invokeai.backend.util.silence_warnings import SilenceWarnings
try:
@@ -234,7 +233,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
assert isinstance(config, MainGGUFCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings(), GGUFPatcher().wrap():
with SilenceWarnings():
# Load the state dict and patcher
sd = gguf_sd_loader(model_path)
# Initialize the model
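For context on this hunk: with GGUFPatcher gone, the with-block only silences warnings while gguf_sd_loader builds a dict of GGMLTensor values, and the tensors themselves intercept ops via __torch_dispatch__ (see the second file below). A hedged sketch of the surrounding flow; the checkpoint path and the assign=True step are illustrative assumptions, not taken from this diff.

# Hedged sketch of the loading flow; the path and the assign=True step are assumptions,
# not shown in this diff.
from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings

model_path = Path("/models/flux1-dev-Q4_K_S.gguf")  # hypothetical checkpoint location
with SilenceWarnings():
    sd = gguf_sd_loader(model_path)  # str -> GGMLTensor

# Applying the state dict to an already-constructed transformer might then look like:
# model.load_state_dict(sd, assign=True)  # assign=True keeps the GGMLTensor wrappers intact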

View File

@@ -8,7 +8,48 @@ from invokeai.backend.quantization.gguf.utils import (
)
class GGMLTensor:
def dequantize_and_run(func, args, kwargs):
# TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.bfloat16.
dequantized_args = [
a.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(a, "get_dequantized_tensor") else a for a in args
]
dequantized_kwargs = {
k: v.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(v, "get_dequantized_tensor") else v
for k, v in kwargs.items()
}
return func(*dequantized_args, **dequantized_kwargs)
def apply_to_quantized_tensor(func, args, kwargs):
ggml_tensor = args[0]
assert isinstance(ggml_tensor, GGMLTensor)
new_data = func(ggml_tensor._data, *args[1:], **kwargs)
return GGMLTensor(new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape)
GGML_TENSOR_OP_TABLE = {
torch.ops.aten.detach.default: apply_to_quantized_tensor,
torch.ops.aten._to_copy.default: apply_to_quantized_tensor,
# --
torch.ops.aten.t.default: dequantize_and_run,
torch.ops.aten.addmm.default: dequantize_and_run,
torch.ops.aten.mul.Tensor: dequantize_and_run,
}
class GGMLTensor(torch.Tensor):
@staticmethod
def __new__(cls, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
return torch.Tensor._make_wrapper_subclass(
cls,
data.shape,
dtype=data.dtype,
layout=data.layout,
device=data.device,
strides=data.stride(),
storage_offset=data.storage_offset(),
)
def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
self._data = data
self._ggml_quantization_type = ggml_quantization_type
@@ -18,6 +59,17 @@ class GGMLTensor:
def __repr__(self):
return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"
def size(self):
return self._tensor_shape
@property
def shape(self):
return self.size()
def requires_grad_(self, requires_grad: bool = True):
# TODO(ryand): Think about whether we should set requires_grad on the underlying tensor.
return self
def get_dequantized_tensor(self, dtype: torch.dtype):
"""Return the dequantized tensor.
@@ -37,23 +89,7 @@ class GGMLTensor:
return torch.from_numpy(new).to(self._data.device, dtype=dtype)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
# Most functions will work by simply running on the dequantized tensors, so we assume this as the default
# behavior. Over time, we will have to add special handling for exceptions. For example, .to() will need special
# handling.
if func in []:
return NotImplemented
else:
# TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.float32.
dequantized_args = [
a.get_dequantized_tensor(dtype=torch.float32) if hasattr(a, "get_dequantized_tensor") else a
for a in args
]
dequantized_kwargs = {
k: v.get_dequantized_tensor(dtype=torch.float32) if hasattr(v, "get_dequantized_tensor") else v
for k, v in kwargs.items()
}
return func(*dequantized_args, **dequantized_kwargs)
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in GGML_TENSOR_OP_TABLE:
return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
raise NotImplementedError(f"Unsupported function {func}")
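Ops not listed in GGML_TENSOR_OP_TABLE raise NotImplementedError, so coverage grows as new ops are hit at runtime. The self-contained toy below (ToyTensor and run_on_unwrapped are hypothetical names, not this commit's code) shows the same wrapper-subclass plus op-table mechanics end to end.

# Simplified stand-in for illustration only; not the commit's GGMLTensor.
import torch


def run_on_unwrapped(func, args, kwargs):
    # Unwrap any ToyTensor inputs, then run the real aten op on the plain tensors.
    def unwrap(x):
        return x._data if isinstance(x, ToyTensor) else x

    return func(*[unwrap(a) for a in args], **{k: unwrap(v) for k, v in (kwargs or {}).items()})


OP_TABLE = {
    torch.ops.aten.mul.Tensor: run_on_unwrapped,
    torch.ops.aten.add.Tensor: run_on_unwrapped,
}


class ToyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        # Wrapper subclass: the outer tensor only carries metadata; the payload lives in _data.
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every aten op involving a ToyTensor lands here and is looked up in the table.
        if func in OP_TABLE:
            return OP_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"Unsupported op: {func}")


t = ToyTensor(torch.ones(2, 2))
other = torch.full((2, 2), 2.0)
print(t * other)  # routed through torch.ops.aten.mul.Tensor -> run_on_unwrapped
print(t + other)  # routed through torch.ops.aten.add.Tensor -> run_on_unwrapped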

View File

@@ -5,64 +5,78 @@ from pathlib import Path
import gguf
import torch
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.utils import detect_arch
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
def gguf_sd_loader(
path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
) -> dict[str, GGUFTensor]:
"""
Read state dict as fake tensors
"""
def gguf_sd_loader(path: Path) -> dict[str, GGMLTensor]:
reader = gguf.GGUFReader(path)
prefix_len = len(handle_prefix)
tensor_names = {tensor.name for tensor in reader.tensors}
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
tensors: list[tuple[str, gguf.ReaderTensor]] = []
sd: dict[str, GGMLTensor] = {}
for tensor in reader.tensors:
sd_key = tensor_name = tensor.name
if has_prefix:
if not tensor_name.startswith(handle_prefix):
continue
sd_key = tensor_name[prefix_len:]
tensors.append((sd_key, tensor))
# detect and verify architecture
compat = None
arch_str = None
arch_field = reader.get_field("general.architecture")
if arch_field is not None:
if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
if arch_str not in {"flux"}:
raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
else:
arch_str = detect_arch({val[0] for val in tensors})
compat = "sd.cpp"
# main loading loop
state_dict: dict[str, GGUFTensor] = {}
qtype_dict: dict[str, int] = {}
for sd_key, tensor in tensors:
tensor_name = tensor.name
tensor_type_str = str(tensor.tensor_type)
torch_tensor = torch.from_numpy(tensor.data) # mmap
torch_tensor = torch.from_numpy(tensor.data)
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
# Workaround for stable-diffusion.cpp SDXL detection.
if compat == "sd.cpp" and arch_str == "sdxl":
if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
while len(shape) > 2 and shape[-1] == 1:
shape = shape[:-1]
# add to state dict
if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
torch_tensor = torch_tensor.view(*shape)
state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
return sd
return state_dict
# def gguf_sd_loader(
# path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
# ) -> dict[str, GGUFTensor]:
# """
# Read state dict as fake tensors
# """
# reader = gguf.GGUFReader(path)
# prefix_len = len(handle_prefix)
# tensor_names = {tensor.name for tensor in reader.tensors}
# has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
# tensors: list[tuple[str, gguf.ReaderTensor]] = []
# for tensor in reader.tensors:
# sd_key = tensor_name = tensor.name
# if has_prefix:
# if not tensor_name.startswith(handle_prefix):
# continue
# sd_key = tensor_name[prefix_len:]
# tensors.append((sd_key, tensor))
# # detect and verify architecture
# compat = None
# arch_str = None
# arch_field = reader.get_field("general.architecture")
# if arch_field is not None:
# if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
# raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
# arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
# if arch_str not in {"flux"}:
# raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
# else:
# arch_str = detect_arch({val[0] for val in tensors})
# compat = "sd.cpp"
# # main loading loop
# state_dict: dict[str, GGUFTensor] = {}
# qtype_dict: dict[str, int] = {}
# for sd_key, tensor in tensors:
# tensor_name = tensor.name
# tensor_type_str = str(tensor.tensor_type)
# torch_tensor = torch.from_numpy(tensor.data) # mmap
# shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
# # Workaround for stable-diffusion.cpp SDXL detection.
# if compat == "sd.cpp" and arch_str == "sdxl":
# if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
# while len(shape) > 2 and shape[-1] == 1:
# shape = shape[:-1]
# # add to state dict
# if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
# torch_tensor = torch_tensor.view(*shape)
# state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
# qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
# return state_dict
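The new loader wraps every tensor in the GGUF file in a GGMLTensor, only reshaping those whose dtype is already torch-compatible, and leaves dequantization to the caller. A hedged usage sketch follows; the checkpoint path is a placeholder, not from this commit.

# Usage sketch; the checkpoint path is hypothetical.
from pathlib import Path

import torch

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

sd = gguf_sd_loader(Path("/models/flux1-dev-Q8_0.gguf"))  # hypothetical GGUF checkpoint
key, value = next(iter(sd.items()))
print(key, value.shape)  # shape reports the logical (dequantized) shape, not the packed data shape
dense = value.get_dequantized_tensor(dtype=torch.bfloat16)  # plain bfloat16 copy for inspection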

View File

@@ -51,10 +51,10 @@ dependencies = [
"sentencepiece==0.2.0",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2",
"torch==2.4.1",
"torchmetrics==0.11.4",
"torchsde==0.2.6",
"torchvision==0.17.2",
"torchvision==0.19.1",
"transformers==4.41.1",
# Core application dependencies, pinned for reproducible builds.