mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 12:37:45 +08:00
test(model_management): add a couple tests for _get_model_path
This commit is contained in:
parent
65ed224bfc
commit
44bf308192
@ -533,7 +533,7 @@ class ModelManager(object):
|
||||
model_path = self.resolve_model_path(model_path)
|
||||
return model_path, is_submodel_override
|
||||
|
||||
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
|
||||
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
|
||||
"""Get a model's config object."""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
try:
|
||||
|
@ -100,7 +100,7 @@ dependencies = [
|
||||
"dev" = [
|
||||
"pudb",
|
||||
]
|
||||
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
|
||||
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
|
||||
"xformers" = [
|
||||
"xformers~=0.0.19; sys_platform!='darwin'",
|
||||
"triton; sys_platform=='linux'",
|
||||
|
36
tests/test_model_manager.py
Normal file
36
tests/test_model_manager.py
Normal file
@ -0,0 +1,36 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager(datadir) -> ModelManager:
|
||||
InvokeAIAppConfig.get_config(root=datadir)
|
||||
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
|
||||
|
||||
|
||||
def test_get_model_names(model_manager: ModelManager):
|
||||
names = model_manager.model_names()
|
||||
assert names[:2] == [
|
||||
("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main),
|
||||
("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main),
|
||||
]
|
||||
|
||||
|
||||
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL base", ModelType.Main)
|
||||
top_model_path, is_override = model_manager._get_model_path(model_config)
|
||||
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
|
||||
assert top_model_path == expected_model_path
|
||||
assert not is_override
|
||||
|
||||
|
||||
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL with VAE", ModelType.Main)
|
||||
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
|
||||
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
|
||||
assert vae_model_path == expected_vae_path
|
||||
assert is_override
|
15
tests/test_model_manager/configs/relative_sub.models.yaml
Normal file
15
tests/test_model_manager/configs/relative_sub.models.yaml
Normal file
@ -0,0 +1,15 @@
|
||||
__metadata__:
|
||||
version: 3.0.0
|
||||
|
||||
sdxl/main/SDXL base:
|
||||
path: sdxl/main/SDXL base 1_0
|
||||
description: SDXL base v1.0
|
||||
variant: normal
|
||||
format: diffusers
|
||||
|
||||
sdxl/main/SDXL with VAE:
|
||||
path: sdxl/main/SDXL base 1_0
|
||||
description: SDXL base v1.0
|
||||
vae: sdxl/vae/sdxl-vae-fp16-fix/
|
||||
variant: normal
|
||||
format: diffusers
|
Loading…
Reference in New Issue
Block a user