mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 04:18:46 +08:00
make download and convert cache keys safe for filename length
This commit is contained in:
parent
bb04f496e0
commit
a26667d3ca
@ -7,6 +7,7 @@ from pathlib import Path
|
||||
|
||||
from invokeai.backend.util import GIG, directory_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import safe_filename
|
||||
|
||||
from .convert_cache_base import ModelConvertCacheBase
|
||||
|
||||
@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
|
||||
|
||||
def cache_path(self, key: str) -> Path:
|
||||
"""Return the path for a model with the indicated key."""
|
||||
key = safe_filename(self._cache_path, key)
|
||||
return self._cache_path / key
|
||||
|
||||
def make_room(self, size: float) -> None:
|
||||
|
@ -19,7 +19,6 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.util import slugify
|
||||
|
||||
|
||||
# TO DO: The loader is not thread safe!
|
||||
@ -85,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
cache_path: Path = self._convert_cache.cache_path(slugify(model_path))
|
||||
cache_path: Path = self._convert_cache.cache_path(str(model_path))
|
||||
if self._needs_conversion(config, model_path, cache_path):
|
||||
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
|
||||
else:
|
||||
|
@ -18,7 +18,8 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
"""
|
||||
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
|
||||
dashes to single dashes. Remove characters that aren't alphanumerics,
|
||||
underscores, or hyphens. Convert to lowercase. Also strip leading and
|
||||
underscores, or hyphens. Replace slashes with underscores.
|
||||
Convert to lowercase. Also strip leading and
|
||||
trailing whitespace, dashes, and underscores.
|
||||
|
||||
Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
|
||||
@ -29,10 +30,17 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
else:
|
||||
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
|
||||
value = re.sub(r"[/]", "_", value.lower())
|
||||
value = re.sub(r"[^\w\s-]", "", value.lower())
|
||||
value = re.sub(r"[^.\w\s-]", "", value.lower())
|
||||
return re.sub(r"[-\s]+", "-", value).strip("-_")
|
||||
|
||||
|
||||
def safe_filename(directory: Path, value: str) -> str:
|
||||
"""Make a string safe to use as a filename."""
|
||||
escaped_string = slugify(value)
|
||||
max_name_length = os.pathconf(directory, "PC_NAME_MAX")
|
||||
return escaped_string[len(escaped_string) - max_name_length :]
|
||||
|
||||
|
||||
def directory_size(directory: Path) -> int:
|
||||
"""
|
||||
Return the aggregate size of all files in a directory (bytes).
|
||||
|
@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
@ -22,7 +23,7 @@ def mock_context(
|
||||
)
|
||||
|
||||
|
||||
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path):
|
||||
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
|
||||
downloaded_path = mock_context.models.download_and_cache_ckpt(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
@ -37,13 +38,29 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path)
|
||||
assert downloaded_path == downloaded_path_2
|
||||
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext):
|
||||
def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
|
||||
downloaded_path = mock_context.models.download_and_cache_ckpt(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_1 = mock_context.models.load_ckpt_from_path(downloaded_path)
|
||||
assert isinstance(loaded_model_1, LoadedModel)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_ckpt_from_path(downloaded_path)
|
||||
assert isinstance(loaded_model_2, LoadedModel)
|
||||
assert loaded_model_1.model is loaded_model_2.model
|
||||
|
||||
loaded_model_3 = mock_context.models.load_ckpt_from_path(embedding_file)
|
||||
assert isinstance(loaded_model_3, LoadedModel)
|
||||
assert loaded_model_1.model is not loaded_model_3.model
|
||||
assert isinstance(loaded_model_1.model, dict)
|
||||
assert isinstance(loaded_model_3.model, dict)
|
||||
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
|
||||
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||
loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
|
||||
assert isinstance(loaded_model_1, LoadedModel)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
|
||||
assert isinstance(loaded_model_2, LoadedModel)
|
||||
|
||||
with loaded_model_1 as model_1, loaded_model_2 as model_2:
|
||||
assert model_1 == model_2
|
||||
assert isinstance(model_1, dict)
|
||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||
|
Loading…
Reference in New Issue
Block a user