# Mirror of https://github.com/invoke-ai/InvokeAI.git
# (synced 2025-04-04 22:43:40 +08:00; 94 lines, 3.7 KiB, Python)
from pathlib import Path
|
|
from typing import Any, Optional, TypeAlias
|
|
|
|
import safetensors.torch
|
|
import torch
|
|
from picklescan.scanner import scan_file_path
|
|
|
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
|
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
|
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
|
|
|
# Type of a loaded model state dict (as returned by ModelOnDisk.load_state_dict):
# keys are str or int; values are untyped here (presumably tensors or nested
# checkpoint data -- not enforced by this alias).
StateDict: TypeAlias = dict[str | int, Any]
|
|
|
|
|
|
class ModelOnDisk:
|
|
"""A utility class representing a model stored on disk."""
|
|
|
|
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
|
|
self.path = path
|
|
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
|
self.name = path.stem
|
|
else:
|
|
self.name = path.name
|
|
self.hash_algo = hash_algo
|
|
self.cache = {}
|
|
self._state_dict_cache = {}
|
|
|
|
def hash(self) -> str:
|
|
return ModelHash(algorithm=self.hash_algo).hash(self.path)
|
|
|
|
def size(self) -> int:
|
|
if self.path.is_file():
|
|
return self.path.stat().st_size
|
|
return sum(file.stat().st_size for file in self.path.rglob("*"))
|
|
|
|
def component_paths(self) -> set[Path]:
|
|
if self.path.is_file():
|
|
return {self.path}
|
|
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
|
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
|
|
|
def repo_variant(self) -> Optional[ModelRepoVariant]:
|
|
if self.path.is_file():
|
|
return None
|
|
|
|
weight_files = list(self.path.glob("**/*.safetensors"))
|
|
weight_files.extend(list(self.path.glob("**/*.bin")))
|
|
for x in weight_files:
|
|
if ".fp16" in x.suffixes:
|
|
return ModelRepoVariant.FP16
|
|
if "openvino_model" in x.name:
|
|
return ModelRepoVariant.OpenVINO
|
|
if "flax_model" in x.name:
|
|
return ModelRepoVariant.Flax
|
|
if x.suffix == ".onnx":
|
|
return ModelRepoVariant.ONNX
|
|
return ModelRepoVariant.Default
|
|
|
|
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
|
|
if path in self._state_dict_cache:
|
|
return self._state_dict_cache[path]
|
|
|
|
if not path:
|
|
components = list(self.component_paths())
|
|
match components:
|
|
case []:
|
|
raise ValueError("No weight files found for this model")
|
|
case [p]:
|
|
path = p
|
|
case ps if len(ps) >= 2:
|
|
raise ValueError(
|
|
f"Multiple weight files found for this model: {ps}. "
|
|
f"Please specify the intended file using the 'path' argument"
|
|
)
|
|
|
|
with SilenceWarnings():
|
|
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
|
scan_result = scan_file_path(path)
|
|
if scan_result.infected_files != 0 or scan_result.scan_err:
|
|
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
|
|
checkpoint = torch.load(path, map_location="cpu")
|
|
assert isinstance(checkpoint, dict)
|
|
elif path.suffix.endswith(".gguf"):
|
|
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
|
|
elif path.suffix.endswith(".safetensors"):
|
|
checkpoint = safetensors.torch.load_file(path)
|
|
else:
|
|
raise ValueError(f"Unrecognized model extension: {path.suffix}")
|
|
|
|
state_dict = checkpoint.get("state_dict", checkpoint)
|
|
self._state_dict_cache[path] = state_dict
|
|
return state_dict
|