mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-04 22:43:40 +08:00
implement suggestions from first review by @psychedelicious
This commit is contained in:
parent
f73b678aae
commit
620b2d477a
@ -455,17 +455,23 @@ The `import_model()` method is the core of the installer. The
|
||||
following illustrates basic usage:
|
||||
|
||||
```
|
||||
sources = [
|
||||
Path('/opt/models/sushi.safetensors'), # a local safetensors file
|
||||
Path('/opt/models/sushi_diffusers/'), # a local diffusers folder
|
||||
'runwayml/stable-diffusion-v1-5', # a repo_id
|
||||
'runwayml/stable-diffusion-v1-5:vae', # a subfolder within a repo_id
|
||||
'https://civitai.com/api/download/models/63006', # a civitai direct download link
|
||||
'https://civitai.com/models/8765?modelVersionId=10638', # civitai model page
|
||||
'https://s3.amazon.com/fjacks/sd-3.safetensors', # arbitrary URL
|
||||
]
|
||||
from invokeai.app.services.model_install import (
|
||||
LocalModelSource,
|
||||
HFModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
|
||||
for source in sources:
|
||||
source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
|
||||
source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local diffusers folder
|
||||
|
||||
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
|
||||
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
|
||||
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
|
||||
|
||||
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
|
||||
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
|
||||
|
||||
for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
install_job = installer.install_model(source)
|
||||
|
||||
source2job = installer.wait_for_installs()
|
||||
|
@ -186,25 +186,11 @@ async def add_model_record(
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource = Body(
|
||||
description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
|
||||
),
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
variant: Optional[str] = Body(
|
||||
description="When fetching a repo_id, force variant type to fetch such as 'fp16'",
|
||||
default=None,
|
||||
),
|
||||
subfolder: Optional[str] = Body(
|
||||
description="When fetching a repo_id, specify subfolder to fetch model from",
|
||||
default=None,
|
||||
),
|
||||
access_token: Optional[str] = Body(
|
||||
description="When fetching a repo_id or URL, access token for web access",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
@ -212,6 +198,38 @@ async def import_model(
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The source object is a discriminated Union of LocalModelSource,
|
||||
HFModelSource and URLModelSource. Set the "type" field to the
|
||||
appropriate value:
|
||||
|
||||
* To install a local path using LocalModelSource, pass a source of form:
|
||||
`{
|
||||
"type": "local",
|
||||
"path": "/path/to/model",
|
||||
"inplace": false
|
||||
}`
|
||||
The "inplace" flag, if true, will register the model in place in its
|
||||
current filesystem location. Otherwise, the model will be copied
|
||||
into the InvokeAI models directory.
|
||||
|
||||
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
|
||||
`{
|
||||
"type": "hf",
|
||||
"repo_id": "stabilityai/stable-diffusion-2.0",
|
||||
"variant": "fp16",
|
||||
"subfolder": "vae",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}`
|
||||
The `variant`, `subfolder` and `access_token` fields are optional.
|
||||
|
||||
* To install a remote model using an arbitrary URL, pass:
|
||||
`{
|
||||
"type": "url",
|
||||
"url": "http://www.civitai.com/models/123456",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}`
|
||||
The `access_token` field is optonal
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
@ -234,11 +252,8 @@ async def import_model(
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source,
|
||||
source=source,
|
||||
config=config,
|
||||
variant=variant,
|
||||
subfolder=subfolder,
|
||||
access_token=access_token,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
|
@ -383,17 +383,17 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return db_dir / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self) -> Optional[Path]:
|
||||
def model_conf_path(self) -> Path:
|
||||
"""Path to models configuration file."""
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self) -> Optional[Path]:
|
||||
def legacy_conf_path(self) -> Path:
|
||||
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def models_path(self) -> Optional[Path]:
|
||||
def models_path(self) -> Path:
|
||||
"""Path to the models directory."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
|
@ -1,11 +1,15 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import (
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
ModelSourceValidator,
|
||||
UnknownInstallJobException,
|
||||
URLModelSource,
|
||||
)
|
||||
from .model_install_default import ModelInstallService
|
||||
|
||||
@ -16,4 +20,8 @@ __all__ = [
|
||||
"ModelInstallJob",
|
||||
"UnknownInstallJobException",
|
||||
"ModelSource",
|
||||
"ModelSourceValidator",
|
||||
"LocalModelSource",
|
||||
"HFModelSource",
|
||||
"URLModelSource",
|
||||
]
|
||||
|
@ -1,11 +1,14 @@
|
||||
import re
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import Body
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
@ -27,7 +30,74 @@ class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
ModelSource = Union[str, Path, AnyHttpUrl]
|
||||
class StringLikeSource(BaseModel):
|
||||
"""Base class for model sources, implements functions that lets the source be sorted and indexed."""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: Any) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""A HuggingFace repo_id, with optional variant and sub-folder."""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[str] = None
|
||||
subfolder: Optional[str | Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
base += f":{self.subfolder}" if self.subfolder else ""
|
||||
base += f" ({self.variant})" if self.variant else ""
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["generic_url"] = "generic_url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]
|
||||
ModelSourceValidator = TypeAdapter(ModelSource)
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
@ -74,6 +144,7 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Call at InvokeAI startup time."""
|
||||
self.sync_to_config()
|
||||
|
||||
@property
|
||||
@ -139,25 +210,12 @@ class ModelInstallServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def import_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install the indicated model.
|
||||
|
||||
:param source: Either a URL or a HuggingFace repo_id.
|
||||
|
||||
:param inplace: If True, local paths will not be moved into
|
||||
the models directory, but registered in place (the default).
|
||||
|
||||
:param variant: For HuggingFace models, this optional parameter
|
||||
specifies which variant to download (e.g. 'fp16')
|
||||
|
||||
:param subfolder: When downloading HF repo_ids this can be used to
|
||||
specify a subfolder of the HF repository to download from.
|
||||
:param source: ModelSource object
|
||||
|
||||
:param config: Optional dict. Any fields in this dict
|
||||
will override corresponding autoassigned probe fields in the
|
||||
@ -165,9 +223,6 @@ class ModelInstallServiceBase(ABC):
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
|
||||
:param access_token: Access token for use in downloading remote
|
||||
models.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
This call is executed asynchronously in a separate
|
||||
@ -196,7 +251,7 @@ class ModelInstallServiceBase(ABC):
|
||||
"""Return the ModelInstallJob corresponding to the provided source."""
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102
|
||||
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
|
||||
"""
|
||||
List active and complete install jobs.
|
||||
|
||||
@ -208,11 +263,11 @@ class ModelInstallServiceBase(ABC):
|
||||
"""Prune all completed and errored jobs."""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:
|
||||
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
This will block until all pending installs have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
@ -234,3 +289,12 @@ class ModelInstallServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the model record database."""
|
||||
|
||||
@abstractmethod
|
||||
def release(self) -> None:
|
||||
"""
|
||||
Signal the install thread to exit.
|
||||
|
||||
This is useful if you are done with the installer and wish to
|
||||
release its resources.
|
||||
"""
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import threading
|
||||
from hashlib import sha256
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from random import randbytes
|
||||
@ -24,6 +25,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
|
||||
from .model_install_base import (
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
@ -31,7 +33,10 @@ from .model_install_base import (
|
||||
)
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
|
||||
STOP_JOB = ModelInstallJob(
|
||||
source=LocalModelSource(path="stop"),
|
||||
local_path=Path("/dev/null"),
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
@ -42,7 +47,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
_event_bus: Optional[EventServiceBase] = None
|
||||
_install_queue: Queue[ModelInstallJob]
|
||||
_install_jobs: Dict[ModelSource, ModelInstallJob]
|
||||
_logger: InvokeAILogger
|
||||
_logger: Logger
|
||||
_cached_model_paths: Set[Path]
|
||||
_models_installed: Set[str]
|
||||
|
||||
@ -109,11 +114,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _signal_job_running(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.RUNNING
|
||||
self._logger.info(f"{job.source}: model installation started")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_started(str(job.source))
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
assert job.config_out
|
||||
self._logger.info(
|
||||
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
@ -122,6 +132,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||
job.set_error(excp)
|
||||
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}")
|
||||
if self._event_bus:
|
||||
error_type = job.error_type
|
||||
error = job.error
|
||||
@ -151,7 +162,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
|
||||
old_hash = info.original_hash
|
||||
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
@ -167,26 +177,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def import_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
inplace: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob: # noqa D102
|
||||
# Clean up a common source of error. Doesn't work with Paths.
|
||||
if isinstance(source, str):
|
||||
source = source.strip()
|
||||
|
||||
if not config:
|
||||
config = {}
|
||||
|
||||
# Installing a local path
|
||||
if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk
|
||||
if isinstance(source, LocalModelSource) and Path(source.path).exists(): # a path that is already on disk
|
||||
job = ModelInstallJob(
|
||||
config_in=config,
|
||||
source=source,
|
||||
inplace=inplace,
|
||||
local_path=Path(source),
|
||||
config_in=config,
|
||||
local_path=Path(source.path),
|
||||
)
|
||||
self._install_jobs[source] = job
|
||||
self._install_queue.put(job)
|
||||
@ -195,13 +196,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
|
||||
raise UnknownModelException("File or directory not found")
|
||||
|
||||
def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102
|
||||
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
|
||||
jobs = self._install_jobs
|
||||
if not source:
|
||||
return list(jobs.values())
|
||||
else:
|
||||
source = str(source)
|
||||
return [jobs[x] for x in jobs if source in str(x)]
|
||||
return [jobs[x] for x in jobs if str(source) in str(x)]
|
||||
|
||||
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
|
||||
try:
|
||||
@ -344,6 +344,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def release(self) -> None:
|
||||
"""Stop the install thread and release its resources."""
|
||||
self._install_queue.put(STOP_JOB)
|
||||
|
||||
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
|
@ -12,6 +12,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import (
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
@ -124,9 +125,10 @@ def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config
|
||||
|
||||
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||
"""Note: may want to break this down into several smaller unit tests."""
|
||||
source = test_file
|
||||
path = test_file
|
||||
description = "Test of metadata assignment"
|
||||
job = installer.import_model(source, inplace=False, config={"description": description})
|
||||
source = LocalModelSource(path=path, inplace=False)
|
||||
job = installer.import_model(source, config={"description": description})
|
||||
assert job is not None
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
|
||||
@ -147,8 +149,8 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert "model_install_started" in event_names
|
||||
assert "model_install_completed" in event_names
|
||||
assert Path(bus.events[0].payload["source"]) == Path(source)
|
||||
assert Path(bus.events[1].payload["source"]) == Path(source)
|
||||
assert Path(bus.events[0].payload["source"]) == source
|
||||
assert Path(bus.events[1].payload["source"]) == source
|
||||
key = bus.events[1].payload["key"]
|
||||
assert key is not None
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user