mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-08 11:57:36 +08:00
refactor multifile download code
This commit is contained in:
parent
2dae5eb7ad
commit
d968c6f379
@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
|
||||
specify the source and destination of the download, and keep track of
|
||||
the progress of the download.
|
||||
|
||||
The only job type currently implemented is `DownloadJob`, a pydantic object with the
|
||||
Two job types are defined. `DownloadJob` and
|
||||
`MultiFileDownloadJob`. The former is a pydantic object with the
|
||||
following fields:
|
||||
|
||||
| **Field** | **Type** | **Default** | **Description** |
|
||||
@ -138,7 +139,7 @@ following fields:
|
||||
| `dest` | Path | | Where to download to |
|
||||
| `access_token` | str | | [optional] string containing authentication token for access |
|
||||
| `on_start` | Callable | | [optional] callback when the download starts |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_complete` | Callable | | [optional] callback called after successful download completion |
|
||||
| `on_error` | Callable | | [optional] callback called after an error occurs |
|
||||
| `id` | int | auto assigned | Job ID, an integer >= 0 |
|
||||
@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
|
||||
`error_type` field of "DownloadJobCancelledException". In addition,
|
||||
the job's `cancelled` property will be set to True.
|
||||
|
||||
The `MultiFileDownloadJob` is used for diffusers model downloads,
|
||||
which contain multiple files and directories under a common root:
|
||||
|
||||
| **Field** | **Type** | **Default** | **Description** |
|
||||
|----------------|-----------------|---------------|-----------------|
|
||||
| _Fields passed in at job creation time_ |
|
||||
| `download_parts` | Set[DownloadJob]| | Component download jobs |
|
||||
| `dest` | Path | | Where to download to |
|
||||
| `on_start` | Callable | | [optional] callback when the download starts |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_complete` | Callable | | [optional] callback called after successful download completion |
|
||||
| `on_error` | Callable | | [optional] callback called after an error occurs |
|
||||
| `id` | int | auto assigned | Job ID, an integer >= 0 |
|
||||
| _Fields updated over the course of the download task_
|
||||
| `status` | DownloadJobStatus| | Status code |
|
||||
| `download_path` | Path | | Path to the root of the downloaded files |
|
||||
| `bytes` | int | 0 | Bytes downloaded so far |
|
||||
| `total_bytes` | int | 0 | Total size of the file at the remote site |
|
||||
| `error_type` | str | | String version of the exception that caused an error during download |
|
||||
| `error` | str | | String version of the traceback associated with an error |
|
||||
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
|
||||
|
||||
Note that the MultiFileDownloadJob does not support the `priority`,
|
||||
`job_started`, `job_ended` or `content_type` attributes. You can get
|
||||
these from the individual download jobs in `download_parts`.
|
||||
|
||||
|
||||
### Callbacks
|
||||
|
||||
Download jobs can be associated with a series of callbacks, each with
|
||||
@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
|
||||
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
|
||||
with `join()`.
|
||||
|
||||
#### job = queue.download(source, dest, priority, access_token)
|
||||
#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
|
||||
|
||||
Create a new download job and put it on the queue, returning the
|
||||
DownloadJob object.
|
||||
|
||||
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
|
||||
|
||||
This is similar to download(), but instead of taking a single source,
|
||||
it accepts a `parts` argument consisting of a list of
|
||||
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
|
||||
where the URL is the location of the remote file, and the Path is the
|
||||
destination.
|
||||
|
||||
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
|
||||
consists of a url/path pair. Note that the path *must* be relative.
|
||||
|
||||
The method returns a `MultiFileDownloadJob`.
|
||||
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
||||
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors'',
|
||||
path='my_model/textencoder/pytorch_model.safetensors'
|
||||
)
|
||||
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
|
||||
path='my_model/vae/diffusers_model.safetensors'
|
||||
)
|
||||
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
|
||||
dest='/tmp/downloads',
|
||||
on_progress=TqdmProgress().update)
|
||||
queue.wait_for_job(job)
|
||||
print(f"The files were downloaded to {job.download_path}")
|
||||
```
|
||||
|
||||
#### jobs = queue.list_jobs()
|
||||
|
||||
Return a list of all active and inactive `DownloadJob`s.
|
||||
|
@ -1577,3 +1577,41 @@ This method takes a model key, looks it up using the
|
||||
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
|
||||
model configuration to `load_model_by_config()`. It may raise a
|
||||
`NotImplementedException`.
|
||||
|
||||
## Invocation Context Model Manager API
|
||||
|
||||
Within invocations, the following methods are available from the
|
||||
`InvocationContext` object:
|
||||
|
||||
### context.download_and_cache_model(source) -> Path
|
||||
|
||||
This method accepts a `source` of a model, downloads and caches it
|
||||
locally, and returns a Path to the local model. The source can be a
|
||||
local file or directory, a URL, or a HuggingFace repo_id.
|
||||
|
||||
In the case of HuggingFace repo_id, the following variants are
|
||||
recognized:
|
||||
|
||||
* stabilityai/stable-diffusion-v4 -- default model
|
||||
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
|
||||
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
|
||||
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
|
||||
|
||||
You can also point at an arbitrary individual file within a repo_id
|
||||
directory using this syntax:
|
||||
|
||||
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
|
||||
### context.load_and_cache_model(source, [loader]) -> LoadedModel
|
||||
|
||||
This method takes a model source, downloads it, caches it, and then
|
||||
loads it into the RAM cache for use in inference. The optional loader
|
||||
is a Callable that accepts a Path to the object, and returns a
|
||||
`Dict[str, torch.Tensor]`. If no loader is provided, then the method
|
||||
will use `torch.load()` for a .ckpt or .bin checkpoint file,
|
||||
`safetensors.torch.load_file()` for a safetensors checkpoint file, or
|
||||
`*.from_pretrained()` for a directory that looks like a
|
||||
diffusers directory.
|
||||
|
||||
|
||||
|
||||
|
@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
|
||||
with context.models.load_and_cache_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
@ -634,8 +634,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
mm = context.models
|
||||
onnx_det = mm.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = mm.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
onnx_det = mm.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = mm.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
processed_image = dw_openpose(
|
||||
|
@ -133,7 +133,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image, context: InvocationContext):
|
||||
with context.models.load_ckpt_from_url(
|
||||
with context.models.load_and_cache_model(
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
loader=LaMA.load_jit_model,
|
||||
) as model:
|
||||
|
@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
context.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
loadnet = context.models.load_ckpt_from_url(
|
||||
loadnet = context.models.load_and_cache_model(
|
||||
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||
)
|
||||
|
||||
|
@ -387,12 +387,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
model_source = self._guess_source(source)
|
||||
remote_files, _ = self._remote_files_from_source(model_source)
|
||||
job = self._download_queue.multifile_download(
|
||||
parts=remote_files,
|
||||
job = self._multifile_download(
|
||||
dest=model_path,
|
||||
remote_files=remote_files,
|
||||
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
|
||||
)
|
||||
files_string = "file" if len(remote_files) == 1 else "file"
|
||||
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
|
||||
files_string = "file" if len(remote_files) == 1 else "files"
|
||||
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
|
||||
self._download_queue.wait_for_job(job)
|
||||
if job.complete:
|
||||
assert job.download_path is not None
|
||||
@ -734,26 +735,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
)
|
||||
# remember the temporary directory for later removal
|
||||
install_job._install_tmpdir = destdir
|
||||
install_job.total_bytes = sum((x.size or 0) for x in remote_files)
|
||||
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
root = Path(".")
|
||||
subfolder = Path(".")
|
||||
|
||||
parts: List[RemoteModelFile] = []
|
||||
for model_file in remote_files:
|
||||
assert install_job.total_bytes is not None
|
||||
assert model_file.size is not None
|
||||
install_job.total_bytes += model_file.size
|
||||
parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
|
||||
multifile_job = self._multifile_download(
|
||||
parts=parts,
|
||||
remote_files=remote_files,
|
||||
dest=destdir,
|
||||
subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
|
||||
access_token=source.access_token,
|
||||
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
|
||||
)
|
||||
@ -776,8 +763,35 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return size
|
||||
|
||||
def _multifile_download(
|
||||
self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True
|
||||
self,
|
||||
remote_files: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
subfolder: Optional[Path] = None,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
) -> MultiFileDownloadJob:
|
||||
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
|
||||
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
|
||||
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
|
||||
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
|
||||
if subfolder:
|
||||
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
|
||||
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
|
||||
path_to_add = Path(f"{top}_{subfolder}")
|
||||
else:
|
||||
path_to_remove = Path(".")
|
||||
path_to_add = Path(".")
|
||||
|
||||
parts: List[RemoteModelFile] = []
|
||||
for model_file in remote_files:
|
||||
assert model_file.size is not None
|
||||
parts.append(
|
||||
RemoteModelFile(
|
||||
url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json
|
||||
path=path_to_add / model_file.path.relative_to(path_to_remove),
|
||||
)
|
||||
)
|
||||
|
||||
return self._download_queue.multifile_download(
|
||||
parts=parts,
|
||||
dest=dest,
|
||||
@ -795,56 +809,53 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# ------------------------------------------------------------------
|
||||
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.id]
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
if install_job := self._download_cache.get(download_job.id, None):
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
|
||||
assert download_job.download_path
|
||||
if install_job.local_path == install_job._install_tmpdir: # first time
|
||||
install_job.local_path = download_job.download_path
|
||||
install_job.total_bytes = download_job.total_bytes
|
||||
assert download_job.download_path
|
||||
if install_job.local_path == install_job._install_tmpdir: # first time
|
||||
install_job.local_path = download_job.download_path
|
||||
install_job.total_bytes = download_job.total_bytes
|
||||
|
||||
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.id]
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._download_queue.cancel_job(download_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||
self._signal_job_downloading(install_job)
|
||||
if install_job := self._download_cache.get(download_job.id, None):
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._download_queue.cancel_job(download_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.id)
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job) # this starts the installation and registration
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job) # this starts the installation and registration
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.id)
|
||||
assert install_job is not None
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._download_queue.cancel_job(download_job)
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._download_queue.cancel_job(download_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.id, None)
|
||||
if not install_job:
|
||||
return
|
||||
self._downloads_changed_event.set()
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
self._downloads_changed_event.set()
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Internal methods that put events on the event bus
|
||||
|
@ -43,11 +43,11 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_ckpt_from_path(
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the checkpoint-format model file located at the indicated Path.
|
||||
Load the model file or directory located at the indicated Path.
|
||||
|
||||
This will load an arbitrary model file into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
|
@ -20,6 +20,8 @@ from invokeai.backend.model_manager.load import (
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_load_base import ModelLoadServiceBase
|
||||
@ -94,7 +96,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
def load_ckpt_from_path(
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
@ -128,6 +130,16 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
def diffusers_load_directory(directory: Path) -> AnyModel:
|
||||
load_class = GenericDiffusersLoader(
|
||||
app_config=self._app_config,
|
||||
logger=self._logger,
|
||||
ram_cache=self._ram_cache,
|
||||
convert_cache=self.convert_cache,
|
||||
).get_hf_load_class(directory)
|
||||
result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
|
||||
return result
|
||||
|
||||
if loader is None:
|
||||
loader = (
|
||||
torch_load_file
|
||||
|
@ -72,7 +72,7 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_ckpt_from_url(
|
||||
def load_model_from_url(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
|
@ -64,6 +64,34 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
def load_model_from_url(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Download, cache, and Load the model file located at the indicated URL.
|
||||
|
||||
This will check the model download cache for the model designated
|
||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
||||
It will then load the model into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
model_path = self.install.download_and_cache_model(source=str(source))
|
||||
return self.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
@ -102,31 +130,3 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
event_bus=events,
|
||||
)
|
||||
return cls(store=model_record_service, install=installer, load=loader)
|
||||
|
||||
def load_ckpt_from_url(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Download, cache, and Load the model file located at the indicated URL.
|
||||
|
||||
This will check the model download cache for the model designated
|
||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
||||
It will then load the model into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
model_path = self.install.download_and_cache_ckpt(source=source)
|
||||
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
|
||||
|
@ -435,11 +435,9 @@ class ModelsInterface(InvocationContextInterface):
|
||||
)
|
||||
return result
|
||||
|
||||
def download_and_cache_ckpt(
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
access_token: Optional[str] = None,
|
||||
timeout: Optional[int] = 0,
|
||||
) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
@ -449,12 +447,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
installed, the cached path will be returned. Otherwise it will be downloaded.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
access_token: Optional access token for restricted resources.
|
||||
timeout: Wait up to the indicated number of seconds before timing
|
||||
out long downloads.
|
||||
|
||||
source: A model path, URL or repo_id.
|
||||
Result:
|
||||
Path to the downloaded model
|
||||
|
||||
@ -463,39 +456,14 @@ class ModelsInterface(InvocationContextInterface):
|
||||
TimeoutError
|
||||
"""
|
||||
installer = self._services.model_manager.install
|
||||
path: Path = installer.download_and_cache_ckpt(
|
||||
path: Path = installer.download_and_cache_model(
|
||||
source=source,
|
||||
access_token=access_token,
|
||||
timeout=timeout,
|
||||
)
|
||||
return path
|
||||
|
||||
def load_ckpt_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the checkpoint-format model file located at the indicated Path.
|
||||
|
||||
This will load an arbitrary model file into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
model_path: A pathlib.Path to a checkpoint-style models file
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader)
|
||||
return result
|
||||
|
||||
def load_ckpt_from_url(
|
||||
def load_and_cache_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
source: Path | str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
@ -511,14 +479,17 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
source: A model Path, URL, or repoid.
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader)
|
||||
result: LoadedModel = (
|
||||
self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
||||
if isinstance(source, Path)
|
||||
else self._services.model_manager.load_model_from_url(source=source, loader=loader)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
|
@ -304,6 +304,19 @@ def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError) as error:
|
||||
queue.multifile_download(
|
||||
parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
|
||||
dest=tmp_path,
|
||||
)
|
||||
assert str(error.value) == "only relative download paths accepted"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
|
@ -24,7 +24,7 @@ def mock_context(
|
||||
|
||||
|
||||
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
|
||||
downloaded_path = mock_context.models.download_and_cache_ckpt(
|
||||
downloaded_path = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert downloaded_path.is_file()
|
||||
@ -32,24 +32,24 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path)
|
||||
assert downloaded_path.name == "test_embedding.safetensors"
|
||||
assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"
|
||||
|
||||
downloaded_path_2 = mock_context.models.download_and_cache_ckpt(
|
||||
downloaded_path_2 = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert downloaded_path == downloaded_path_2
|
||||
|
||||
|
||||
def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
|
||||
downloaded_path = mock_context.models.download_and_cache_ckpt(
|
||||
downloaded_path = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_1 = mock_context.models.load_ckpt_from_path(downloaded_path)
|
||||
loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
|
||||
assert isinstance(loaded_model_1, LoadedModel)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_ckpt_from_path(downloaded_path)
|
||||
loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
|
||||
assert isinstance(loaded_model_2, LoadedModel)
|
||||
assert loaded_model_1.model is loaded_model_2.model
|
||||
|
||||
loaded_model_3 = mock_context.models.load_ckpt_from_path(embedding_file)
|
||||
loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
|
||||
assert isinstance(loaded_model_3, LoadedModel)
|
||||
assert loaded_model_1.model is not loaded_model_3.model
|
||||
assert isinstance(loaded_model_1.model, dict)
|
||||
@ -58,9 +58,25 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
|
||||
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||
loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
|
||||
loaded_model_1 = mock_context.models.load_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert isinstance(loaded_model_1, LoadedModel)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
|
||||
loaded_model_2 = mock_context.models.load_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert isinstance(loaded_model_2, LoadedModel)
|
||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||
|
||||
|
||||
def test_download_diffusers(mock_context: InvocationContext) -> None:
|
||||
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
|
||||
assert (model_path / "model_index.json").exists()
|
||||
assert (model_path / "vae").is_dir()
|
||||
|
||||
|
||||
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
|
||||
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
|
||||
assert model_path.is_dir()
|
||||
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists()
|
||||
|
Loading…
Reference in New Issue
Block a user