fix(config): edge cases in models.yaml migration

When running the configurator, the `legacy_models_yaml_path` field was stripped when saving the config file. As a result, the migration logic didn't fire correctly, and custom models.yaml paths weren't migrated into the db.

- Rework the logic to migrate this path by adding it to the config object as a normal field that is not excluded from serialization.
- Rearrange the models.yaml migration logic to remove the legacy path after migrating, then write the config file. This way, the legacy path doesn't stick around.
- Move the schema version into the config object.
- Back up the config file before attempting migration.
- Add tests to cover this edge case
This commit is contained in:
psychedelicious 2024-03-15 23:21:21 +11:00
parent 1ed1c1fb24
commit e76cc71e81
5 changed files with 92 additions and 55 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import os
import re
import shutil
from functools import lru_cache
from pathlib import Path
from typing import Any, Literal, Optional
@ -46,12 +47,6 @@ class URLRegexTokenPair(BaseModel):
return v
class ConfigMeta(BaseModel):
"""Metadata for the config file. This is not stored in the config object."""
schema_version: int = CONFIG_SCHEMA_VERSION
class InvokeAIAppConfig(BaseSettings):
"""Invoke's global app configuration.
@ -109,6 +104,10 @@ class InvokeAIAppConfig(BaseSettings):
# fmt: off
# INTERNAL
schema_version: int = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.")
# WEB
host: str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.")
port: int = Field(default=9090, description="Port to bind to.")
@ -175,11 +174,6 @@ class InvokeAIAppConfig(BaseSettings):
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
# HIDDEN FIELDS
# v4 (MM2) doesn't use `models.yaml` files, but users were able to set paths in the v3 config. When we migrate a
# v3 config, we need to save the path to the models.yaml. This is only used during migration.
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="The `conf_path` setting from a v3 `invokeai.yaml` file. Only present this app session migrated a config file, and it had `conf_test` on it.", exclude=True)
# fmt: on
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)
@ -217,8 +211,20 @@ class InvokeAIAppConfig(BaseSettings):
dest_path: Path to write the config to.
"""
with open(dest_path, "w") as file:
meta_dict = {"meta": ConfigMeta().model_dump()}
config_dict = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=True)
# Meta fields should be written in a separate stanza
meta_dict = self.model_dump(mode="json", include={"schema_version"})
# Only include the legacy_models_yaml_path if it's set
if self.legacy_models_yaml_path:
meta_dict.update(self.model_dump(mode="json", include={"legacy_models_yaml_path"}))
# User settings
config_dict = self.model_dump(
mode="json",
exclude_unset=True,
exclude_defaults=True,
exclude={"schema_version", "legacy_models_yaml_path"},
)
file.write("# Internal metadata - do not edit:\n")
file.write(yaml.dump(meta_dict, sort_keys=False))
file.write("\n")
@ -370,11 +376,12 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
config = migrate_v3_config_dict(loaded_config_dict)
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
config_path.rename(config_path.with_suffix(".yaml.bak"))
# By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set
config.write_file(config_path)
return config
@ -382,11 +389,11 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config_meta = ConfigMeta.model_validate(loaded_config_dict.pop("meta"))
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config_meta.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config_meta.schema_version}"
return InvokeAIAppConfig.model_validate(loaded_config_dict)
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e

View File

@ -292,10 +292,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
)
if not legacy_models_yaml_path.exists():
# No yaml to migrate
return
if legacy_models_yaml_path.exists():
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
yaml_metadata = legacy_models_yaml.pop("__metadata__")
@ -332,6 +329,10 @@ class ModelInstallService(ModelInstallServiceBase):
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
# Remove `legacy_models_yaml_path` from the config file - we are done with it either way
self._app_config.legacy_models_yaml_path = None
self._app_config.write_file(self._app_config.init_file_path)
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register

View File

@ -34,6 +34,7 @@ from transformers import AutoFeatureExtractor
import invokeai.configs as model_configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
from invokeai.backend.model_manager import ModelType
from invokeai.backend.util import choose_precision, choose_torch_device
@ -63,8 +64,7 @@ def get_literal_fields(field: str) -> Tuple[Any]:
# --------------------------globals-----------------------
# Start from a fresh config object - we will read the user's config from file later, and update it with their choices
config = InvokeAIAppConfig()
config = get_config()
PRECISION_CHOICES = get_literal_fields("precision")
DEVICE_CHOICES = get_literal_fields("device")

View File

@ -3,6 +3,8 @@ from typing import Literal, get_args, get_type_hints
from invokeai.app.services.config.config_default import InvokeAIAppConfig
_excluded = {"schema_version", "legacy_models_yaml_path"}
def generate_config_docstrings() -> str:
"""Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class.
@ -20,7 +22,7 @@ def generate_config_docstrings() -> str:
type_hints = get_type_hints(InvokeAIAppConfig)
for k, v in InvokeAIAppConfig.model_fields.items():
if v.exclude:
if v.exclude or k in _excluded:
continue
field_type = type_hints.get(k)
extra = ""

View File

@ -9,7 +9,6 @@ from pydantic import ValidationError
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config
v4_config = """
meta:
schema_version: 4
host: "192.168.1.1"
@ -17,7 +16,6 @@ port: 8080
"""
invalid_v5_config = """
meta:
schema_version: 5
host: "192.168.1.1"
@ -44,6 +42,12 @@ InvokeAI:
max_vram_cache_size: 50
"""
v3_config_with_bad_values = """
InvokeAI:
Web Server:
port: "ice cream"
"""
invalid_config = """
i like turtles
"""
@ -88,6 +92,29 @@ def test_migrate_v3_config_from_file(tmp_path: Path):
assert not hasattr(config, "esrgan")
def test_migrate_v3_backup(tmp_path: Path):
    """Migrating a v3 config must leave a `.yaml.bak` copy holding the original v3 text."""
    config_file = tmp_path / "temp_invokeai.yaml"
    config_file.write_text(v3_config)
    # The backup sits next to the config file, with `.yaml` swapped for `.yaml.bak`.
    backup_file = config_file.with_suffix(".yaml.bak")

    load_and_migrate_config(config_file)

    assert backup_file.exists()
    assert backup_file.read_text() == v3_config
def test_failed_migrate_backup(tmp_path: Path):
    """A failed v3 migration must raise and leave both the backup and the original file intact."""
    config_file = tmp_path / "temp_invokeai.yaml"
    config_file.write_text(v3_config_with_bad_values)
    backup_file = config_file.with_suffix(".yaml.bak")

    # The bad values in the v3 config are expected to make migration fail.
    with pytest.raises(RuntimeError):
        load_and_migrate_config(config_file)

    # The backup exists and matches the text that was written.
    assert backup_file.exists()
    assert backup_file.read_text() == v3_config_with_bad_values
    # The original config file is still present with its contents unchanged.
    assert config_file.exists()
    assert config_file.read_text() == v3_config_with_bad_values
def test_bails_on_invalid_config(tmp_path: Path):
"""Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"