make download and convert cache keys safe for filename length

Lincoln Stein 2024-04-28 12:24:36 -04:00
parent bb04f496e0
commit a26667d3ca
4 changed files with 36 additions and 10 deletions


@@ -7,6 +7,7 @@ from pathlib import Path
 from invokeai.backend.util import GIG, directory_size
 from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.backend.util.util import safe_filename

 from .convert_cache_base import ModelConvertCacheBase
@@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
     def cache_path(self, key: str) -> Path:
         """Return the path for a model with the indicated key."""
+        key = safe_filename(self._cache_path, key)
         return self._cache_path / key

     def make_room(self, size: float) -> None:
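
A minimal sketch of what the new guard in cache_path does, assuming a POSIX host; the cache directory and key below are hypothetical, and safe_filename is the helper added in the util module further down:

from pathlib import Path

from invokeai.backend.util.util import safe_filename

# Hypothetical convert-cache directory and a key derived from a model path.
cache_dir = Path("/tmp/invokeai/convert_cache")
cache_dir.mkdir(parents=True, exist_ok=True)  # os.pathconf needs an existing dir
key = "/Users/someone/models/sd-1.5/a-very-long-checkpoint-name.safetensors"

# safe_filename() slugifies the key (slashes become underscores) and clamps it
# to the filesystem's maximum component length (PC_NAME_MAX), so the joined
# path below can always be created.
print(cache_dir / safe_filename(cache_dir, key))
# e.g. /tmp/invokeai/convert_cache/users_someone_models_sd-1.5_a-very-long-checkpoint-name.safetensors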


@@ -19,7 +19,6 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.util import slugify

 # TO DO: The loader is not thread safe!
@@ -85,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
         except IndexError:
             pass
-        cache_path: Path = self._convert_cache.cache_path(slugify(model_path))
+        cache_path: Path = self._convert_cache.cache_path(str(model_path))
         if self._needs_conversion(config, model_path, cache_path):
             loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
         else:
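
The loader-side change is the mirror image of the cache-side change: sanitization now lives in one place. A rough before/after sketch of the data flow, illustrative only:

# Before this commit: the loader slugified the key itself, and the convert
# cache used whatever string it received as a literal path component.
#   cache_path = self._convert_cache.cache_path(slugify(model_path))
#
# After: the loader passes the raw path string, and cache_path() applies
# safe_filename() internally, so every caller gets length-safe names.
#   cache_path = self._convert_cache.cache_path(str(model_path))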


@@ -18,7 +18,8 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
     """
     Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
     dashes to single dashes. Remove characters that aren't alphanumerics,
-    underscores, or hyphens. Convert to lowercase. Also strip leading and
+    underscores, or hyphens. Replace slashes with underscores.
+    Convert to lowercase. Also strip leading and
     trailing whitespace, dashes, and underscores.

     Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
@@ -29,10 +30,17 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
     else:
         value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
-    value = re.sub(r"[^\w\s-]", "", value.lower())
+    value = re.sub(r"[/]", "_", value.lower())
+    value = re.sub(r"[^.\w\s-]", "", value.lower())
     return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
+def safe_filename(directory: Path, value: str) -> str:
+    """Make a string safe to use as a filename."""
+    escaped_string = slugify(value)
+    max_name_length = os.pathconf(directory, "PC_NAME_MAX")
+    return escaped_string[len(escaped_string) - max_name_length :]


 def directory_size(directory: Path) -> int:
     """
     Return the aggregate size of all files in a directory (bytes).
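
A short sketch of the new helpers' behavior, with values chosen for illustration. Two caveats worth noting: os.pathconf exists only on POSIX systems (it is absent on Windows), and the tail-slice in safe_filename as written can also shorten slugs whose length falls between PC_NAME_MAX/2 and PC_NAME_MAX, because a small negative slice start is read as an offset from the end of the string; escaped_string[-max_name_length:] would return such slugs unchanged.

import os
from pathlib import Path

from invokeai.backend.util.util import safe_filename, slugify

# Slashes become underscores, dots now survive, everything is lowercased.
assert slugify("sd-1.5/v1-5-Pruned.safetensors") == "sd-1.5_v1-5-pruned.safetensors"

# Keys longer than the filesystem limit keep their tail, so the file
# extension is preserved.
tmp = Path("/tmp")
limit = os.pathconf(tmp, "PC_NAME_MAX")  # typically 255 on Linux/macOS
long_key = "m" * (limit + 50) + ".safetensors"
name = safe_filename(tmp, long_key)
assert len(name) == limit
assert name.endswith(".safetensors")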


@@ -1,6 +1,7 @@
 from pathlib import Path

 import pytest
+import torch

 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_manager import ModelManagerServiceBase
@@ -22,7 +23,7 @@ def mock_context(
 )
-def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path):
+def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
     downloaded_path = mock_context.models.download_and_cache_ckpt(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
@@ -37,13 +38,29 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path)
     assert downloaded_path == downloaded_path_2

-def test_download_and_load(mock_context: InvocationContext):
+def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
+    downloaded_path = mock_context.models.download_and_cache_ckpt(
+        "https://www.test.foo/download/test_embedding.safetensors"
+    )
+    loaded_model_1 = mock_context.models.load_ckpt_from_path(downloaded_path)
+    assert isinstance(loaded_model_1, LoadedModel)
+
+    loaded_model_2 = mock_context.models.load_ckpt_from_path(downloaded_path)
+    assert isinstance(loaded_model_2, LoadedModel)
+    assert loaded_model_1.model is loaded_model_2.model
+
+    loaded_model_3 = mock_context.models.load_ckpt_from_path(embedding_file)
+    assert isinstance(loaded_model_3, LoadedModel)
+    assert loaded_model_1.model is not loaded_model_3.model
+    assert isinstance(loaded_model_1.model, dict)
+    assert isinstance(loaded_model_3.model, dict)
+    assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
+
+
+def test_download_and_load(mock_context: InvocationContext) -> None:
     loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
     assert isinstance(loaded_model_1, LoadedModel)
     loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
     assert isinstance(loaded_model_2, LoadedModel)

     with loaded_model_1 as model_1, loaded_model_2 as model_2:
         assert model_1 == model_2
         assert isinstance(model_1, dict)
         assert loaded_model_1.model is loaded_model_2.model  # should be cached copy
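
Not part of this commit, but a sketch of how the length guarantee itself could be pinned down in a test, following the style above; the key is hypothetical, tmp_path is pytest's built-in fixture, and the test assumes a POSIX host where os.pathconf is available:

import os
from pathlib import Path

from invokeai.backend.util.util import safe_filename


def test_safe_filename_respects_name_max(tmp_path: Path) -> None:
    limit = os.pathconf(tmp_path, "PC_NAME_MAX")
    very_long_key = "https://www.test.foo/download/" + "x" * (limit * 2) + ".safetensors"
    name = safe_filename(tmp_path, very_long_key)
    assert 0 < len(name) <= limit
    (tmp_path / name).touch()  # the component is actually creatable on this filesystem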