Get alternative GGUF implementation working... barely.

commit f06765dfba (parent f347b26999)
@@ -37,7 +37,6 @@ from invokeai.backend.model_manager.util.model_util import (
    convert_bundle_to_flux_transformer_checkpoint,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher
from invokeai.backend.util.silence_warnings import SilenceWarnings

try:
@@ -234,7 +233,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
        assert isinstance(config, MainGGUFCheckpointConfig)
        model_path = Path(config.path)

        with SilenceWarnings(), GGUFPatcher().wrap():
        with SilenceWarnings():
            # Load the state dict and patcher
            sd = gguf_sd_loader(model_path)
            # Initialize the model
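For orientation, a minimal sketch (not part of this diff) of how a loader might turn the GGUF state dict into a usable model once the `GGUFPatcher` context manager is no longer needed. `model_cls` and `model_params` are placeholder names, not InvokeAI APIs.

```python
from pathlib import Path

import torch

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings


def load_gguf_model(model_path: Path, model_cls, model_params):
    with SilenceWarnings():
        # Read the quantized tensors; each value in `sd` is a GGMLTensor wrapper.
        sd = gguf_sd_loader(model_path)
        # Build the module skeleton without allocating real weights, then attach
        # the quantized state dict; assign=True keeps the wrapper tensors as-is.
        with torch.device("meta"):
            model = model_cls(model_params)
        model.load_state_dict(sd, strict=False, assign=True)
    return model
```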
@@ -8,7 +8,48 @@ from invokeai.backend.quantization.gguf.utils import (
)


class GGMLTensor:
def dequantize_and_run(func, args, kwargs):
    # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.bfloat16.
    dequantized_args = [
        a.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(a, "get_dequantized_tensor") else a for a in args
    ]
    dequantized_kwargs = {
        k: v.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(v, "get_dequantized_tensor") else v
        for k, v in kwargs.items()
    }
    return func(*dequantized_args, **dequantized_kwargs)


def apply_to_quantized_tensor(func, args, kwargs):
    ggml_tensor = args[0]
    assert isinstance(ggml_tensor, GGMLTensor)
    new_data = func(ggml_tensor._data, *args[1:], **kwargs)
    return GGMLTensor(new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape)


GGML_TENSOR_OP_TABLE = {
    torch.ops.aten.detach.default: apply_to_quantized_tensor,
    torch.ops.aten._to_copy.default: apply_to_quantized_tensor,
    # --
    torch.ops.aten.t.default: dequantize_and_run,
    torch.ops.aten.addmm.default: dequantize_and_run,
    torch.ops.aten.mul.Tensor: dequantize_and_run,
}


class GGMLTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )

    def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
        self._data = data
        self._ggml_quantization_type = ggml_quantization_type
@@ -18,6 +59,17 @@ class GGMLTensor:
    def __repr__(self):
        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape={self._tensor_shape})"

    def size(self):
        return self._tensor_shape

    @property
    def shape(self):
        return self.size()

    def requires_grad_(self, requires_grad: bool = True):
        # TODO(ryand): Think about whether we should set requires_grad on the underlying tensor.
        return self

    def get_dequantized_tensor(self, dtype: torch.dtype):
        """Return the dequantized tensor.

@@ -37,23 +89,7 @@ class GGMLTensor:
        return torch.from_numpy(new).to(self._data.device, dtype=dtype)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Most functions will work by simply running on the dequantized tensors, so we assume this as the default
        # behavior. Over time, we will have to add special handling for exceptions. For example, .to() will need special
        # handling.
        if func in []:
            return NotImplemented
        else:
            # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.float32.
            dequantized_args = [
                a.get_dequantized_tensor(dtype=torch.float32) if hasattr(a, "get_dequantized_tensor") else a
                for a in args
            ]
            dequantized_kwargs = {
                k: v.get_dequantized_tensor(dtype=torch.float32) if hasattr(v, "get_dequantized_tensor") else v
                for k, v in kwargs.items()
            }
            return func(*dequantized_args, **dequantized_kwargs)
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in GGML_TENSOR_OP_TABLE:
            return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"Unsupported function {func}")
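For reference, a minimal sketch (not part of this diff) of how the dispatch table above plays out in practice: ops routed to `apply_to_quantized_tensor` keep the data quantized, while ops routed to `dequantize_and_run` execute on dequantized copies. The quantization type, shapes, and placeholder buffer below are made up for illustration.

```python
import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

# Wrap a placeholder packed buffer the way gguf_sd_loader would for a Q4_K weight.
packed = torch.zeros(64, dtype=torch.uint8)  # stand-in for real GGUF block data
weight = GGMLTensor(packed, gguf.GGMLQuantizationType.Q4_K, torch.Size([32, 16]))

# aten.detach.default is in GGML_TENSOR_OP_TABLE -> apply_to_quantized_tensor,
# so the result is still a GGMLTensor carrying the quantized bytes.
assert isinstance(weight.detach(), GGMLTensor)

# A linear layer lowers to aten.t / aten.addmm -> dequantize_and_run, so the
# weight would be dequantized (to bfloat16, per the table) right before the
# matmul executes. With real packed GGUF data this would look like:
# x = torch.randn(1, 16, dtype=torch.bfloat16)
# y = torch.nn.functional.linear(x, weight)
```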
@@ -5,64 +5,78 @@ from pathlib import Path
import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.utils import detect_arch
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES


def gguf_sd_loader(
    path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
) -> dict[str, GGUFTensor]:
    """
    Read state dict as fake tensors
    """
def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
    reader = gguf.GGUFReader(path)

    prefix_len = len(handle_prefix)
    tensor_names = {tensor.name for tensor in reader.tensors}
    has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

    tensors: list[tuple[str, gguf.ReaderTensor]] = []
    sd: dict[str, GGUFTensor] = {}
    for tensor in reader.tensors:
        sd_key = tensor_name = tensor.name
        if has_prefix:
            if not tensor_name.startswith(handle_prefix):
                continue
            sd_key = tensor_name[prefix_len:]
        tensors.append((sd_key, tensor))

    # detect and verify architecture
    compat = None
    arch_str = None
    arch_field = reader.get_field("general.architecture")
    if arch_field is not None:
        if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
            raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
        arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
        if arch_str not in {"flux"}:
            raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
    else:
        arch_str = detect_arch({val[0] for val in tensors})
        compat = "sd.cpp"

    # main loading loop
    state_dict: dict[str, GGUFTensor] = {}
    qtype_dict: dict[str, int] = {}
    for sd_key, tensor in tensors:
        tensor_name = tensor.name
        tensor_type_str = str(tensor.tensor_type)
        torch_tensor = torch.from_numpy(tensor.data)  # mmap

        torch_tensor = torch.from_numpy(tensor.data)
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        # Workaround for stable-diffusion.cpp SDXL detection.
        if compat == "sd.cpp" and arch_str == "sdxl":
            if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
                while len(shape) > 2 and shape[-1] == 1:
                    shape = shape[:-1]

        # add to state dict
        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
            torch_tensor = torch_tensor.view(*shape)
        state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
    return sd

    return state_dict

# def gguf_sd_loader(
#     path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
# ) -> dict[str, GGUFTensor]:
#     """
#     Read state dict as fake tensors
#     """
#     reader = gguf.GGUFReader(path)

#     prefix_len = len(handle_prefix)
#     tensor_names = {tensor.name for tensor in reader.tensors}
#     has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

#     tensors: list[tuple[str, gguf.ReaderTensor]] = []
#     for tensor in reader.tensors:
#         sd_key = tensor_name = tensor.name
#         if has_prefix:
#             if not tensor_name.startswith(handle_prefix):
#                 continue
#             sd_key = tensor_name[prefix_len:]
#         tensors.append((sd_key, tensor))

#     # detect and verify architecture
#     compat = None
#     arch_str = None
#     arch_field = reader.get_field("general.architecture")
#     if arch_field is not None:
#         if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
#             raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
#         arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
#         if arch_str not in {"flux"}:
#             raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
#     else:
#         arch_str = detect_arch({val[0] for val in tensors})
#         compat = "sd.cpp"

#     # main loading loop
#     state_dict: dict[str, GGUFTensor] = {}
#     qtype_dict: dict[str, int] = {}
#     for sd_key, tensor in tensors:
#         tensor_name = tensor.name
#         tensor_type_str = str(tensor.tensor_type)
#         torch_tensor = torch.from_numpy(tensor.data)  # mmap

#         shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
#         # Workaround for stable-diffusion.cpp SDXL detection.
#         if compat == "sd.cpp" and arch_str == "sdxl":
#             if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
#                 while len(shape) > 2 and shape[-1] == 1:
#                     shape = shape[:-1]

#         # add to state dict
#         if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
#             torch_tensor = torch_tensor.view(*shape)
#         state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
#         qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

#     return state_dict
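As a quick aid (not part of the commit), the same `gguf.GGUFReader` API the loader relies on can be used to inspect which GGML quantization types a checkpoint contains; the file name below is a made-up example.

```python
from collections import Counter
from pathlib import Path

import gguf

# Count tensors per quantization type in a .gguf checkpoint (illustrative path).
reader = gguf.GGUFReader(Path("flux1-dev-Q4_K_S.gguf"))
qtype_counts = Counter(tensor.tensor_type.name for tensor in reader.tensors)
for qtype, count in sorted(qtype_counts.items()):
    print(f"{qtype}: {count} tensors")
```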
@@ -51,10 +51,10 @@ dependencies = [
    "sentencepiece==0.2.0",
    "spandrel==0.3.4",
    "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
    "torch==2.2.2",
    "torch==2.4.1",
    "torchmetrics==0.11.4",
    "torchsde==0.2.6",
    "torchvision==0.17.2",
    "torchvision==0.19.1",
    "transformers==4.41.1",

# Core application dependencies, pinned for reproducible builds.
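The torch and torchvision pins are bumped alongside the new GGUF code. A small hedged sketch of a runtime check that mirrors the new floor, assuming 2.4.x is the intended minimum:

```python
from packaging.version import Version

import torch

# Fail fast if the installed torch is older than the 2.4.x pin above.
if Version(torch.__version__.split("+")[0]) < Version("2.4.0"):
    raise RuntimeError(f"torch>=2.4 is expected here; found {torch.__version__}")
```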