Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2025-01-09 04:18:46 +08:00

Commit f06765dfba (parent f347b26999)

Get alternative GGUF implementation working... barely.
@@ -37,7 +37,6 @@ from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
 )
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
-from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
@@ -234,7 +233,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
         assert isinstance(config, MainGGUFCheckpointConfig)
         model_path = Path(config.path)
 
-        with SilenceWarnings(), GGUFPatcher().wrap():
+        with SilenceWarnings():
            # Load the state dict and patcher
             sd = gguf_sd_loader(model_path)
             # Initialize the model
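Note: with GGUFPatcher gone, nothing is monkey-patched during load; the GGMLTensor subclass introduced below intercepts torch ops itself via __torch_dispatch__. A hypothetical sketch of how a state dict of such wrappers could then be applied to a module follows; the helper name and the assign=True choice are assumptions, not this commit's verbatim call path:

    import torch

    def apply_gguf_state_dict(module: torch.nn.Module, sd: dict[str, torch.Tensor]) -> None:
        # Hypothetical: assign=True swaps the parameter objects for the GGMLTensor
        # wrappers instead of copying their data into the existing dense parameters.
        module.load_state_dict(sd, assign=True)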
@@ -8,7 +8,48 @@ from invokeai.backend.quantization.gguf.utils import (
 )
 
 
-class GGMLTensor:
+def dequantize_and_run(func, args, kwargs):
+    # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.bfloat16.
+    dequantized_args = [
+        a.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(a, "get_dequantized_tensor") else a for a in args
+    ]
+    dequantized_kwargs = {
+        k: v.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(v, "get_dequantized_tensor") else v
+        for k, v in kwargs.items()
+    }
+    return func(*dequantized_args, **dequantized_kwargs)
+
+
+def apply_to_quantized_tensor(func, args, kwargs):
+    ggml_tensor = args[0]
+    assert isinstance(ggml_tensor, GGMLTensor)
+    new_data = func(ggml_tensor._data, *args[1:], **kwargs)
+    return GGMLTensor(new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape)
+
+
+GGML_TENSOR_OP_TABLE = {
+    # Ops applied directly to the raw quantized payload; the result stays quantized.
+    torch.ops.aten.detach.default: apply_to_quantized_tensor,
+    torch.ops.aten._to_copy.default: apply_to_quantized_tensor,
+    # Ops that compute on dequantized values.
+    torch.ops.aten.t.default: dequantize_and_run,
+    torch.ops.aten.addmm.default: dequantize_and_run,
+    torch.ops.aten.mul.Tensor: dequantize_and_run,
+}
+
+
+class GGMLTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            data.shape,
+            dtype=data.dtype,
+            layout=data.layout,
+            device=data.device,
+            strides=data.stride(),
+            storage_offset=data.storage_offset(),
+        )
+
     def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
         self._data = data
         self._ggml_quantization_type = ggml_quantization_type
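The split in the op table is the important design decision: detach and _to_copy run while modules are being built or moved, so they must operate on the packed payload and keep it quantized, while compute ops fall back to dequantized math. A hypothetical extension (not in this commit) showing how another payload-preserving op could be registered:

    # Hypothetical: clone can also act directly on the packed quantized payload,
    # so it could reuse apply_to_quantized_tensor and skip a dequantize round-trip.
    GGML_TENSOR_OP_TABLE[torch.ops.aten.clone.default] = apply_to_quantized_tensor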
@@ -18,6 +59,17 @@ class GGMLTensor:
     def __repr__(self):
         return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape={self._tensor_shape})"
 
+    def size(self):
+        return self._tensor_shape
+
+    @property
+    def shape(self):
+        return self.size()
+
+    def requires_grad_(self, requires_grad: bool = True):
+        # TODO(ryand): Think about whether we should set requires_grad on the underlying tensor.
+        return self
+
     def get_dequantized_tensor(self, dtype: torch.dtype):
         """Return the dequantized tensor.
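Because size() and the shape property return self._tensor_shape, consumers see the logical (dequantized) shape rather than the packed payload's shape. A quick check sketch; the payload layout is an assumption (Q8_0 packs each 32-value block as a 2-byte fp16 scale plus 32 int8 values, i.e. 34 bytes per block):

    import gguf
    import torch

    from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

    packed = torch.zeros(32, 34, dtype=torch.uint8)  # assumed packed Q8_0 bytes for a 32x32 weight
    t = GGMLTensor(packed, gguf.GGMLQuantizationType.Q8_0, torch.Size([32, 32]))
    assert t.shape == torch.Size([32, 32])        # logical (dequantized) shape
    assert t._data.shape == torch.Size([32, 34])  # raw quantized payload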
@@ -37,23 +89,7 @@ class GGMLTensor:
         return torch.from_numpy(new).to(self._data.device, dtype=dtype)
 
     @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-
-        # Most functions will work by simply running on the dequantized tensors, so we assume this as the default
-        # behavior. Over time, we will have to add special handling for exceptions. For example, .to() will need special
-        # handling.
-        if func in []:
-            return NotImplemented
-        else:
-            # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.float32.
-            dequantized_args = [
-                a.get_dequantized_tensor(dtype=torch.float32) if hasattr(a, "get_dequantized_tensor") else a
-                for a in args
-            ]
-            dequantized_kwargs = {
-                k: v.get_dequantized_tensor(dtype=torch.float32) if hasattr(v, "get_dequantized_tensor") else v
-                for k, v in kwargs.items()
-            }
-            return func(*dequantized_args, **dequantized_kwargs)
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if func in GGML_TENSOR_OP_TABLE:
+            return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
+        raise NotImplementedError(f"Unsupported function {func}")
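The pattern here, _make_wrapper_subclass plus __torch_dispatch__, is self-contained enough to demonstrate without gguf at all. A minimal standalone sketch; ScaledTensor is illustrative only, with a plain tensor plus a scale factor standing in for the quantized payload:

    import torch

    class ScaledTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, data: torch.Tensor, scale: float):
            return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

        def __init__(self, data: torch.Tensor, scale: float):
            self._data = data
            self._scale = scale

        def get_dequantized_tensor(self, dtype: torch.dtype) -> torch.Tensor:
            # "Dequantization" here is just a scale multiply.
            return (self._data * self._scale).to(dtype)

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # Dequantize every wrapped argument, then run the aten op on plain tensors.
            args = [a.get_dequantized_tensor(torch.float32) if isinstance(a, ScaledTensor) else a for a in args]
            kwargs = {k: v.get_dequantized_tensor(torch.float32) if isinstance(v, ScaledTensor) else v for k, v in kwargs.items()}
            return func(*args, **kwargs)

    weight = ScaledTensor(torch.ones(3, 4), scale=0.5)
    x = torch.randn(2, 4)
    # F.linear decomposes into aten ops (t/addmm), each of which lands in
    # __torch_dispatch__ and runs on the dequantized weight.
    print(torch.nn.functional.linear(x, weight))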
@@ -5,64 +5,78 @@ from pathlib import Path
 import gguf
 import torch
 
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.layers import GGUFTensor
-from invokeai.backend.quantization.gguf.utils import detect_arch
+from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
 
 
-def gguf_sd_loader(
-    path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
-) -> dict[str, GGUFTensor]:
-    """
-    Read state dict as fake tensors
-    """
+def gguf_sd_loader(path: Path) -> dict[str, GGMLTensor]:
     reader = gguf.GGUFReader(path)
 
-    prefix_len = len(handle_prefix)
-    tensor_names = {tensor.name for tensor in reader.tensors}
-    has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
-
-    tensors: list[tuple[str, gguf.ReaderTensor]] = []
+    sd: dict[str, GGMLTensor] = {}
     for tensor in reader.tensors:
-        sd_key = tensor_name = tensor.name
-        if has_prefix:
-            if not tensor_name.startswith(handle_prefix):
-                continue
-            sd_key = tensor_name[prefix_len:]
-        tensors.append((sd_key, tensor))
-
-    # detect and verify architecture
-    compat = None
-    arch_str = None
-    arch_field = reader.get_field("general.architecture")
-    if arch_field is not None:
-        if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
-            raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
-        arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
-        if arch_str not in {"flux"}:
-            raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
-    else:
-        arch_str = detect_arch({val[0] for val in tensors})
-        compat = "sd.cpp"
-
-    # main loading loop
-    state_dict: dict[str, GGUFTensor] = {}
-    qtype_dict: dict[str, int] = {}
-    for sd_key, tensor in tensors:
-        tensor_name = tensor.name
-        tensor_type_str = str(tensor.tensor_type)
-        torch_tensor = torch.from_numpy(tensor.data) # mmap
-
+        torch_tensor = torch.from_numpy(tensor.data)
         shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
-        # Workaround for stable-diffusion.cpp SDXL detection.
-        if compat == "sd.cpp" and arch_str == "sdxl":
-            if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
-                while len(shape) > 2 and shape[-1] == 1:
-                    shape = shape[:-1]
-
-        # add to state dict
-        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
+        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
             torch_tensor = torch_tensor.view(*shape)
-        state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
-        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
-
-    return state_dict
+        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
+    return sd
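Usage is now a single call: every tensor in the file comes back wrapped as a GGMLTensor under its original name. A short sketch; the checkpoint path is illustrative, not a file shipped with the repo:

    from pathlib import Path

    from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

    sd = gguf_sd_loader(Path("/models/flux1-dev-Q4_0.gguf"))
    for name, tensor in list(sd.items())[:3]:
        print(name, tensor)  # GGMLTensor repr shows the qtype and dequantized shape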
+
+# [elided: the previous gguf_sd_loader implementation is kept here verbatim as a
+# commented-out block; it duplicates the removed lines shown above]
@@ -51,10 +51,10 @@ dependencies = [
     "sentencepiece==0.2.0",
     "spandrel==0.3.4",
     "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
-    "torch==2.2.2",
+    "torch==2.4.1",
     "torchmetrics==0.11.4",
     "torchsde==0.2.6",
-    "torchvision==0.17.2",
+    "torchvision==0.19.1",
     "transformers==4.41.1",
 
 # Core application dependencies, pinned for reproducible builds.
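The two bumps move in lockstep: torchvision 0.19.x is built against torch 2.4.x, so pinning one without the other would break the install. A quick environment sanity check, assuming the pinned versions above are what's installed:

    import torch
    import torchvision

    # Sketch: verify the environment matches the pins before loading GGUF models.
    assert torch.__version__.startswith("2.4"), torch.__version__
    assert torchvision.__version__.startswith("0.19"), torchvision.__version__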