From c619348f29a1d80be5bf073a9085f670593f7137 Mon Sep 17 00:00:00 2001 From: Billy Date: Fri, 28 Mar 2025 10:35:13 +1100 Subject: [PATCH] Extract ModelOnDisk to its own module --- invokeai/backend/model_manager/config.py | 103 +----------------- .../backend/model_manager/model_on_disk.py | 93 ++++++++++++++++ scripts/strip_models.py | 4 +- tests/test_model_probe.py | 2 +- 4 files changed, 101 insertions(+), 101 deletions(-) create mode 100644 invokeai/backend/model_manager/model_on_disk.py diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index da5236b35c..385a188b17 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -30,15 +30,13 @@ from inspect import isabstract from pathlib import Path from typing import ClassVar, Literal, Optional, TypeAlias, Union -import safetensors.torch -import torch -from picklescan.scanner import scan_file_path from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict from invokeai.app.util.misc import uuid_string from invokeai.backend.model_hash.hash_validator import validate_hash -from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash +from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk from invokeai.backend.model_manager.taxonomy import ( AnyVariant, BaseModelType, @@ -53,9 +51,7 @@ from invokeai.backend.model_manager.taxonomy import ( SubModelType, ) from invokeai.backend.model_manager.util.model_util import lora_token_vector_length -from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES -from invokeai.backend.util.silence_warnings import SilenceWarnings logger = logging.getLogger(__name__) @@ -69,11 +65,6 @@ class InvalidModelConfigException(Exception): DEFAULTS_PRECISION = Literal["fp16", "fp32"] -class FSLayout(Enum): - FILE = "file" - DIRECTORY = "directory" - - class SubmodelDefinition(BaseModel): path_or_prefix: str model_type: ModelType @@ -104,90 +95,6 @@ class ControlAdapterDefaultSettings(BaseModel): model_config = ConfigDict(extra="forbid") -StateDict: TypeAlias = dict[str | int, Any] - - -class ModelOnDisk: - """A utility class representing a model stored on disk.""" - - def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"): - self.path = path - self.layout = FSLayout.DIRECTORY if path.is_dir() else FSLayout.FILE - if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: - self.name = path.stem - else: - self.name = path.name - self.hash_algo = hash_algo - self.cache = {} - self._state_dict_cache = {} - - def hash(self) -> str: - return ModelHash(algorithm=self.hash_algo).hash(self.path) - - def size(self) -> int: - if self.layout == FSLayout.FILE: - return self.path.stat().st_size - return sum(file.stat().st_size for file in self.path.rglob("*")) - - def component_paths(self) -> set[Path]: - if self.layout == FSLayout.FILE: - return {self.path} - extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"} - return {f for f in self.path.rglob("*") if f.suffix in extensions} - - def repo_variant(self) -> Optional[ModelRepoVariant]: - if self.layout == FSLayout.FILE: - return None - - weight_files = list(self.path.glob("**/*.safetensors")) - weight_files.extend(list(self.path.glob("**/*.bin"))) - for x in weight_files: - if ".fp16" in x.suffixes: - return ModelRepoVariant.FP16 - if "openvino_model" in x.name: - return ModelRepoVariant.OpenVINO - if "flax_model" in x.name: - return ModelRepoVariant.Flax - if x.suffix == ".onnx": - return ModelRepoVariant.ONNX - return ModelRepoVariant.Default - - def load_state_dict(self, path: Optional[Path] = None) -> StateDict: - if path in self._state_dict_cache: - return self._state_dict_cache[path] - - if not path: - components = list(self.component_paths()) - match components: - case []: - raise ValueError("No weight files found for this model") - case [p]: - path = p - case ps if len(ps) >= 2: - raise ValueError( - f"Multiple weight files found for this model: {ps}. " - f"Please specify the intended file using the 'path' argument" - ) - - with SilenceWarnings(): - if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")): - scan_result = scan_file_path(path) - if scan_result.infected_files != 0 or scan_result.scan_err: - raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.") - checkpoint = torch.load(path, map_location="cpu") - assert isinstance(checkpoint, dict) - elif path.suffix.endswith(".gguf"): - checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32) - elif path.suffix.endswith(".safetensors"): - checkpoint = safetensors.torch.load_file(path) - else: - raise ValueError(f"Unrecognized model extension: {path.suffix}") - - state_dict = checkpoint.get("state_dict", checkpoint) - self._state_dict_cache[path] = state_dict - return state_dict - - class MatchSpeed(int, Enum): """Represents the estimated runtime speed of a config's 'matches' method.""" @@ -429,7 +336,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): @classmethod def matches(cls, mod: ModelOnDisk) -> bool: - if mod.layout == FSLayout.DIRECTORY: + if mod.path.is_dir(): return False # Avoid false positive match against ControlLoRA and Diffusers @@ -483,7 +390,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): @classmethod def matches(cls, mod: ModelOnDisk) -> bool: - if mod.layout == FSLayout.FILE: + if mod.path.is_file(): return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers suffixes = ["bin", "safetensors"] @@ -667,7 +574,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): @classmethod def matches(cls, mod: ModelOnDisk) -> bool: - if mod.layout == FSLayout.FILE: + if mod.path.is_file(): return False config_path = mod.path / "config.json" diff --git a/invokeai/backend/model_manager/model_on_disk.py b/invokeai/backend/model_manager/model_on_disk.py new file mode 100644 index 0000000000..ccbf154b1f --- /dev/null +++ b/invokeai/backend/model_manager/model_on_disk.py @@ -0,0 +1,93 @@ +from pathlib import Path +from typing import Any, Optional, TypeAlias + +import safetensors.torch +import torch +from picklescan.scanner import scan_file_path + +from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash +from invokeai.backend.model_manager.taxonomy import ModelRepoVariant +from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader +from invokeai.backend.util.silence_warnings import SilenceWarnings + +StateDict: TypeAlias = dict[str | int, Any] + + +class ModelOnDisk: + """A utility class representing a model stored on disk.""" + + def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"): + self.path = path + if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: + self.name = path.stem + else: + self.name = path.name + self.hash_algo = hash_algo + self.cache = {} + self._state_dict_cache = {} + + def hash(self) -> str: + return ModelHash(algorithm=self.hash_algo).hash(self.path) + + def size(self) -> int: + if self.path.is_file(): + return self.path.stat().st_size + return sum(file.stat().st_size for file in self.path.rglob("*")) + + def component_paths(self) -> set[Path]: + if self.path.is_file(): + return {self.path} + extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"} + return {f for f in self.path.rglob("*") if f.suffix in extensions} + + def repo_variant(self) -> Optional[ModelRepoVariant]: + if self.path.is_file(): + return None + + weight_files = list(self.path.glob("**/*.safetensors")) + weight_files.extend(list(self.path.glob("**/*.bin"))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return ModelRepoVariant.OpenVINO + if "flax_model" in x.name: + return ModelRepoVariant.Flax + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.Default + + def load_state_dict(self, path: Optional[Path] = None) -> StateDict: + if path in self._state_dict_cache: + return self._state_dict_cache[path] + + if not path: + components = list(self.component_paths()) + match components: + case []: + raise ValueError("No weight files found for this model") + case [p]: + path = p + case ps if len(ps) >= 2: + raise ValueError( + f"Multiple weight files found for this model: {ps}. " + f"Please specify the intended file using the 'path' argument" + ) + + with SilenceWarnings(): + if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")): + scan_result = scan_file_path(path) + if scan_result.infected_files != 0 or scan_result.scan_err: + raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.") + checkpoint = torch.load(path, map_location="cpu") + assert isinstance(checkpoint, dict) + elif path.suffix.endswith(".gguf"): + checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32) + elif path.suffix.endswith(".safetensors"): + checkpoint = safetensors.torch.load_file(path) + else: + raise ValueError(f"Unrecognized model extension: {path.suffix}") + + state_dict = checkpoint.get("state_dict", checkpoint) + self._state_dict_cache[path] = state_dict + return state_dict diff --git a/scripts/strip_models.py b/scripts/strip_models.py index 8e1d259ff3..a2ac7804d8 100644 --- a/scripts/strip_models.py +++ b/scripts/strip_models.py @@ -22,7 +22,7 @@ from pathlib import Path import humanize import torch -from invokeai.backend.model_manager.config import FSLayout, ModelOnDisk +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk from invokeai.backend.model_manager.search import ModelSearch @@ -62,7 +62,7 @@ def load_stripped_model(path: Path, *args, **kwargs): def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk: original = ModelOnDisk(original_model_path) - if original.layout == FSLayout.FILE: + if original.path.is_file(): shutil.copy2(original.path, stripped_model_path) else: shutil.copytree(original.path, stripped_model_path, dirs_exist_ok=True) diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index c4a0b34062..a808b04393 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -17,7 +17,6 @@ from invokeai.backend.model_manager.config import ( MainDiffusersConfig, ModelConfigBase, ModelConfigFactory, - ModelOnDisk, get_model_discriminator_value, ) from invokeai.backend.model_manager.legacy_probe import ( @@ -27,6 +26,7 @@ from invokeai.backend.model_manager.legacy_probe import ( get_default_settings_control_adapters, get_default_settings_main, ) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.util.logging import InvokeAILogger