implement suggestions from first review by @psychedelicious

This commit is contained in:
Lincoln Stein 2023-12-04 17:08:33 -05:00
parent f73b678aae
commit 620b2d477a
7 changed files with 177 additions and 78 deletions

View File

@ -455,17 +455,23 @@ The `import_model()` method is the core of the installer. The
following illustrates basic usage:
```
sources = [
Path('/opt/models/sushi.safetensors'), # a local safetensors file
Path('/opt/models/sushi_diffusers/'), # a local diffusers folder
'runwayml/stable-diffusion-v1-5', # a repo_id
'runwayml/stable-diffusion-v1-5:vae', # a subfolder within a repo_id
'https://civitai.com/api/download/models/63006', # a civitai direct download link
'https://civitai.com/models/8765?modelVersionId=10638', # civitai model page
'https://s3.amazon.com/fjacks/sd-3.safetensors', # arbitrary URL
]
from invokeai.app.services.model_install import (
LocalModelSource,
HFModelSource,
URLModelSource,
)
for source in sources:
source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local diffusers folder
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
for source in [source1, source2, source3, source4, source5, source6, source7]:
install_job = installer.install_model(source)
source2job = installer.wait_for_installs()

View File

@ -186,25 +186,11 @@ async def add_model_record(
status_code=201,
)
async def import_model(
source: ModelSource = Body(
description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
),
source: ModelSource,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
variant: Optional[str] = Body(
description="When fetching a repo_id, force variant type to fetch such as 'fp16'",
default=None,
),
subfolder: Optional[str] = Body(
description="When fetching a repo_id, specify subfolder to fetch model from",
default=None,
),
access_token: Optional[str] = Body(
description="When fetching a repo_id or URL, access token for web access",
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL.
@ -212,6 +198,38 @@ async def import_model(
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
The source object is a discriminated Union of LocalModelSource,
HFModelSource and URLModelSource. Set the "type" field to the
appropriate value:
* To install a local path using LocalModelSource, pass a source of form:
`{
"type": "local",
"path": "/path/to/model",
"inplace": false
}`
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
`{
"type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16",
"subfolder": "vae",
"access_token": "f5820a918aaf01"
}`
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass:
`{
"type": "url",
"url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01"
}`
The `access_token` field is optional.
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override.
@ -234,11 +252,8 @@ async def import_model(
try:
installer = ApiDependencies.invoker.services.model_install
result: ModelInstallJob = installer.import_model(
source,
source=source,
config=config,
variant=variant,
subfolder=subfolder,
access_token=access_token,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:

View File

@ -383,17 +383,17 @@ class InvokeAIAppConfig(InvokeAISettings):
return db_dir / DB_FILE
@property
def model_conf_path(self) -> Optional[Path]:
def model_conf_path(self) -> Path:
"""Path to models configuration file."""
return self._resolve(self.conf_path)
@property
def legacy_conf_path(self) -> Optional[Path]:
def legacy_conf_path(self) -> Path:
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
return self._resolve(self.legacy_conf_dir)
@property
def models_path(self) -> Optional[Path]:
def models_path(self) -> Path:
"""Path to the models directory."""
return self._resolve(self.models_dir)

View File

@ -1,11 +1,15 @@
"""Initialization file for model install service package."""
from .model_install_base import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
ModelSourceValidator,
UnknownInstallJobException,
URLModelSource,
)
from .model_install_default import ModelInstallService
@ -16,4 +20,8 @@ __all__ = [
"ModelInstallJob",
"UnknownInstallJobException",
"ModelSource",
"ModelSourceValidator",
"LocalModelSource",
"HFModelSource",
"URLModelSource",
]

View File

@ -1,11 +1,14 @@
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from fastapi import Body
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
@ -27,7 +30,74 @@ class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
ModelSource = Union[str, Path, AnyHttpUrl]
class StringLikeSource(BaseModel):
    """Base class for model sources, implements functions that lets the source be sorted and indexed."""

    def __hash__(self) -> int:
        """Return hash of the stringified source, for indexing."""
        # hash on str(self) so subclasses stay usable as dict keys
        return hash(str(self))

    def __lt__(self, other: Any) -> bool:
        """Return comparison of the stringified version, for sorting."""
        # annotated bool (was int): str < str yields a bool
        return str(self) < str(other)

    def __eq__(self, other: Any) -> bool:
        """Return equality on the stringified version."""
        return str(self) == str(other)
class LocalModelSource(StringLikeSource):
    """A local file or directory path.

    If `inplace` is true the model is registered in its current filesystem
    location; otherwise it is copied into the InvokeAI models directory.
    """

    # file or directory on the local filesystem
    path: str | Path
    # register the model in place instead of copying it
    inplace: Optional[bool] = False
    # discriminator value used by the ModelSource union
    type: Literal["local"] = "local"

    # these methods allow the source to be used in a string-like way,
    # for example as an index into a dict
    def __str__(self) -> str:
        """Return string version of path when string rep needed."""
        # as_posix() normalizes separators so string comparisons are stable
        return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
    """A HuggingFace repo_id, with optional variant and sub-folder."""

    # owner/name identifier of the HuggingFace repository
    repo_id: str
    # optional named variant to fetch, e.g. 'fp16'
    variant: Optional[str] = None
    # optional subfolder of the repo to fetch the model from, e.g. 'vae'
    subfolder: Optional[str | Path] = None
    # optional token for gated/private repos
    access_token: Optional[str] = None
    # discriminator value used by the ModelSource union
    type: Literal["hf"] = "hf"

    @field_validator("repo_id")
    @classmethod
    def proper_repo_id(cls, v: str) -> str:  # noqa D102
        # reject anything that is not a bare "owner/name" pair
        if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
            raise ValueError(f"{v}: invalid repo_id format")
        return v

    def __str__(self) -> str:
        """Return string version of repoid when string rep needed."""
        # format: "repo_id[:subfolder][ (variant)]"
        base: str = self.repo_id
        base += f":{self.subfolder}" if self.subfolder else ""
        base += f" ({self.variant})" if self.variant else ""
        return base
class URLModelSource(StringLikeSource):
    """A generic URL pointing to a checkpoint file."""

    url: AnyHttpUrl
    # optional token sent when fetching the URL
    access_token: Optional[str] = None
    # discriminator value used by the ModelSource union
    # NOTE(review): the import_model API docstring shows "type": "url" —
    # confirm which value clients are expected to send
    type: Literal["generic_url"] = "generic_url"

    def __str__(self) -> str:
        """Return string version of the url when string rep needed."""
        return str(self.url)
# Discriminated union of concrete source types; pydantic/FastAPI select the
# subclass from the "type" field of the incoming payload.
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]

# TypeAdapter for validating/parsing raw data into a ModelSource instance.
ModelSourceValidator = TypeAdapter(ModelSource)
class ModelInstallJob(BaseModel):
@ -74,6 +144,7 @@ class ModelInstallServiceBase(ABC):
"""
def start(self, invoker: Invoker) -> None:
"""Call at InvokeAI startup time."""
self.sync_to_config()
@property
@ -139,25 +210,12 @@ class ModelInstallServiceBase(ABC):
@abstractmethod
def import_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = False,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install the indicated model.
:param source: Either a URL or a HuggingFace repo_id.
:param inplace: If True, local paths will not be moved into
the models directory, but registered in place (the default).
:param variant: For HuggingFace models, this optional parameter
specifies which variant to download (e.g. 'fp16')
:param subfolder: When downloading HF repo_ids this can be used to
specify a subfolder of the HF repository to download from.
:param source: ModelSource object
:param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
@ -165,9 +223,6 @@ class ModelInstallServiceBase(ABC):
`name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`.
:param access_token: Access token for use in downloading remote
models.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
@ -196,7 +251,7 @@ class ModelInstallServiceBase(ABC):
"""Return the ModelInstallJob corresponding to the provided source."""
@abstractmethod
def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
"""
List active and complete install jobs.
@ -208,11 +263,11 @@ class ModelInstallServiceBase(ABC):
"""Prune all completed and errored jobs."""
@abstractmethod
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]:
"""
Wait for all pending installs to complete.
This will block until all pending downloads have
This will block until all pending installs have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
@ -234,3 +289,12 @@ class ModelInstallServiceBase(ABC):
@abstractmethod
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database."""
@abstractmethod
def release(self) -> None:
"""
Signal the install thread to exit.
This is useful if you are done with the installer and wish to
release its resources.
"""

View File

@ -2,6 +2,7 @@
import threading
from hashlib import sha256
from logging import Logger
from pathlib import Path
from queue import Queue
from random import randbytes
@ -24,6 +25,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
from .model_install_base import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
@ -31,7 +33,10 @@ from .model_install_base import (
)
# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
STOP_JOB = ModelInstallJob(
source=LocalModelSource(path="stop"),
local_path=Path("/dev/null"),
)
class ModelInstallService(ModelInstallServiceBase):
@ -42,7 +47,7 @@ class ModelInstallService(ModelInstallServiceBase):
_event_bus: Optional[EventServiceBase] = None
_install_queue: Queue[ModelInstallJob]
_install_jobs: Dict[ModelSource, ModelInstallJob]
_logger: InvokeAILogger
_logger: Logger
_cached_model_paths: Set[Path]
_models_installed: Set[str]
@ -109,11 +114,16 @@ class ModelInstallService(ModelInstallServiceBase):
def _signal_job_running(self, job: ModelInstallJob) -> None:
    # Mark the job running, log it, and (if a bus is configured) emit the
    # model_install_started event keyed by the stringified source.
    job.status = InstallStatus.RUNNING
    self._logger.info(f"{job.source}: model installation started")
    if self._event_bus:
        self._event_bus.emit_model_install_started(str(job.source))
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
assert job.config_out
self._logger.info(
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
)
if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
@ -122,6 +132,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
job.set_error(excp)
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}")
if self._event_bus:
error_type = job.error_type
error = job.error
@ -151,7 +162,6 @@ class ModelInstallService(ModelInstallServiceBase):
config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.original_hash
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
new_path = self._copy_model(model_path, dest_path)
@ -167,26 +177,17 @@ class ModelInstallService(ModelInstallServiceBase):
def import_model(
self,
source: ModelSource,
inplace: bool = False,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob: # noqa D102
# Clean up a common source of error. Doesn't work with Paths.
if isinstance(source, str):
source = source.strip()
if not config:
config = {}
# Installing a local path
if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk
if isinstance(source, LocalModelSource) and Path(source.path).exists(): # a path that is already on disk
job = ModelInstallJob(
config_in=config,
source=source,
inplace=inplace,
local_path=Path(source),
config_in=config,
local_path=Path(source.path),
)
self._install_jobs[source] = job
self._install_queue.put(job)
@ -195,13 +196,12 @@ class ModelInstallService(ModelInstallServiceBase):
else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
raise UnknownModelException("File or directory not found")
def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
jobs = self._install_jobs
if not source:
return list(jobs.values())
else:
source = str(source)
return [jobs[x] for x in jobs if source in str(x)]
return [jobs[x] for x in jobs if str(source) in str(x)]
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
try:
@ -344,6 +344,10 @@ class ModelInstallService(ModelInstallServiceBase):
path.unlink()
self.unregister(key)
def release(self) -> None:
    """Stop the install thread and release its resources."""
    # STOP_JOB is a sentinel; the install worker exits when it dequeues it.
    self._install_queue.put(STOP_JOB)
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path

View File

@ -12,6 +12,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallService,
ModelInstallServiceBase,
@ -124,9 +125,10 @@ def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
"""Note: may want to break this down into several smaller unit tests."""
source = test_file
path = test_file
description = "Test of metadata assignment"
job = installer.import_model(source, inplace=False, config={"description": description})
source = LocalModelSource(path=path, inplace=False)
job = installer.import_model(source, config={"description": description})
assert job is not None
assert isinstance(job, ModelInstallJob)
@ -147,8 +149,8 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
event_names = [x.event_name for x in bus.events]
assert "model_install_started" in event_names
assert "model_install_completed" in event_names
assert Path(bus.events[0].payload["source"]) == Path(source)
assert Path(bus.events[1].payload["source"]) == Path(source)
assert Path(bus.events[0].payload["source"]) == source
assert Path(bus.events[1].payload["source"]) == source
key = bus.events[1].payload["key"]
assert key is not None