add multifile_download() method to download service

This commit is contained in:
Lincoln Stein 2024-05-12 20:14:00 -06:00
parent b48d4a049d
commit 0bf14c2830
6 changed files with 350 additions and 83 deletions

View File

@ -1,10 +1,17 @@
"""Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_base import (
DownloadJob,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
UnknownJobIDException,
)
from .download_default import DownloadQueueService, TqdmProgress
__all__ = [
"DownloadJob",
"MultiFileDownloadJob",
"DownloadQueueServiceBase",
"DownloadQueueService",
"TqdmProgress",

View File

@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Set
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
from invokeai.backend.model_manager.metadata import RemoteModelFile
class DownloadJobStatus(str, Enum):
"""State of a download job."""
@ -33,30 +35,19 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
DownloadEventHandler = Callable[["DownloadJobBase"], None]
DownloadExceptionHandler = Callable[["DownloadJobBase", Optional[Exception]], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
@total_ordering
class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
class DownloadJobBase(BaseModel):
"""Base of classes to monitor and control downloads."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
@ -74,14 +65,6 @@ class DownloadJob(BaseModel):
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
@ -98,6 +81,11 @@ class DownloadJob(BaseModel):
"""Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED
@property
def waiting(self) -> bool:
"""Return true if the job is waiting to run."""
return self.status == DownloadJobStatus.WAITING
@property
def running(self) -> bool:
"""Return true if the job is running."""
@ -154,6 +142,39 @@ class DownloadJob(BaseModel):
self._on_cancelled = on_cancelled
@total_ordering
class DownloadJob(DownloadJobBase):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
class MultiFileDownloadJob(DownloadJobBase):
"""Class to monitor and control multifile downloads."""
download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL."""
@ -201,6 +222,33 @@ class DownloadQueueServiceBase(ABC):
"""
pass
@abstractmethod
def multifile_download(
self,
parts: Set[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
"""
Create and enqueue a multifile download job.
:param parts: Set of URL / filename pairs
:param dest: Path to download to. See below.
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
The `dest` argument is a Path object pointing to a directory. All downloads
with be placed inside this directory. The callbacks will receive the
MultiFileDownloadJob.
"""
pass
@abstractmethod
def submit_download_job(
self,
@ -262,7 +310,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
"""Wait until the indicated download job has reached a terminal state.
This will block until the indicated install job has completed,

View File

@ -18,6 +18,7 @@ from tqdm import tqdm
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.model_manager.metadata import RemoteModelFile
from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
@ -27,6 +28,9 @@ from .download_base import (
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadEventHandler,
MultiFileDownloadExceptionHandler,
MultiFileDownloadJob,
ServiceInactiveException,
UnknownJobIDException,
)
@ -54,10 +58,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
"""
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {}
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
self._stop_event = threading.Event()
self._job_completed_event = threading.Event()
self._job_terminated_event = threading.Event()
self._worker_pool: Set[threading.Thread] = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@ -155,6 +160,49 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
return job
def multifile_download(
self,
parts: Set[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
on_start: Optional[MultiFileDownloadEventHandler] = None,
on_progress: Optional[MultiFileDownloadEventHandler] = None,
on_complete: Optional[MultiFileDownloadEventHandler] = None,
on_cancelled: Optional[MultiFileDownloadEventHandler] = None,
on_error: Optional[MultiFileDownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest)
mfdj.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
for part in parts:
url = part.url
path = dest / part.path
assert path.is_relative_to(dest), "only relative download paths accepted"
job = DownloadJob(
source=url,
dest=path,
access_token=access_token,
)
mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj
for download_job in mfdj.download_parts:
self.submit_download_job(
download_job,
on_start=self._mfd_started,
on_progress=self._mfd_progress,
on_complete=self._mfd_complete,
on_cancelled=self._mfd_cancelled,
on_error=self._mfd_error,
)
return mfdj
def join(self) -> None:
"""Wait for all jobs to complete."""
self._queue.join()
@ -187,7 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
job.cancel()
def cancel_all_jobs(self) -> None:
@ -196,12 +244,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state:
self.cancel_job(job)
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
self._job_completed_event.clear()
if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event
self._job_terminated_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
@ -230,22 +278,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
except Exception as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
finally:
job.job_ended = get_iso_timestamp()
self._job_completed_event.set() # signal a change to terminal state
self._job_terminated_event.set() # signal a change to terminal state
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
self._job_terminated_event.set()
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
@ -339,7 +390,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
# Pull the token from config if it exists and matches the URL
print(self._app_config)
token = None
for pair in self._app_config.remote_api_tokens or []:
if re.search(pair.url_regex, str(source)):
@ -349,25 +399,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
if job.on_start:
try:
job.on_start(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_start")
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
try:
job.on_progress(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_progress")
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_progress(
@ -379,13 +417,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
if job.on_complete:
try:
job.on_complete(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_complete")
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_complete(
@ -396,26 +428,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
return
job.status = DownloadJobStatus.CANCELLED
if job.on_cancelled:
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_cancelled")
if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source))
# if multifile download, then signal the parent
if parent_job := self._download_part2parent.get(job.source, None):
if not parent_job.in_terminal_state:
parent_job.status = DownloadJobStatus.CANCELLED
self._execute_cb(parent_job, "on_cancelled")
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
if job.on_error:
try:
job.on_error(job, excp)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_error", excp)
if self._event_bus:
assert job.error_type
assert job.error
@ -430,6 +457,86 @@ class DownloadQueueService(DownloadQueueServiceBase):
except OSError as excp:
self._logger.warning(excp)
########################################
# callbacks used for multifile downloads
########################################
def _mfd_started(self, download_job: DownloadJob) -> None:
self._logger.info(f"File download started: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.waiting:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.status = DownloadJobStatus.RUNNING
self._execute_cb(mf_job, "on_start")
def _mfd_progress(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.cancelled:
for part in mf_job.download_parts:
self.cancel_job(part)
elif mf_job.running:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
self._execute_cb(mf_job, "on_progress")
def _mfd_complete(self, download_job: DownloadJob) -> None:
self._logger.info(f"Download complete: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
# are there any more active jobs left in this task?
if mf_job.running and all(x.complete for x in mf_job.download_parts):
mf_job.status = DownloadJobStatus.COMPLETED
self._execute_cb(mf_job, "on_complete")
# we're done with this sub-job
self._job_terminated_event.set()
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
self._logger.warning(f"Download cancelled: {download_job.source}")
mf_job.cancel()
for s in mf_job.download_parts:
self.cancel_job(s)
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
mf_job.status = download_job.status
mf_job.error = download_job.error
mf_job.error_type = download_job.error_type
self._execute_cb(mf_job, "on_error", excp)
self._logger.error(
f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
)
for s in [x for x in mf_job.download_parts if x.running]:
self.cancel_job(s)
self._download_part2parent.pop(download_job.source)
self._job_terminated_event.set()
def _execute_cb(
self,
job: DownloadJob | MultiFileDownloadJob,
callback_name: str,
excp: Optional[Exception] = None,
) -> None:
if callback := getattr(job, callback_name, None):
args = [job, excp] if excp else [job]
try:
callback(*args)
except Exception as e:
self._logger.error(
f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
)
def get_pc_name_max(directory: str) -> int:
if hasattr(os, "pathconf"):

View File

@ -689,7 +689,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
print(remote_files)
else:
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
return self._import_remote_model(

View File

@ -37,7 +37,7 @@ class RemoteModelFile(BaseModel):
url: AnyHttpUrl = Field(description="The url to download this model file")
path: Path = Field(description="The path to the file, relative to the model root")
size: int = Field(description="The size of this file, in bytes")
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)

View File

@ -4,7 +4,7 @@ import re
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from typing import Generator, Optional
import pytest
from pydantic.networks import AnyHttpUrl
@ -13,7 +13,8 @@ from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
@ -67,11 +68,116 @@ def session() -> Session:
return sess
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
print(f"bytes = {job.bytes}")
events.add(job.status)
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session),
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert Path(
tmp_path, "sdxl-turbo/model_index.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
assert Path(
tmp_path, "sdxl-turbo/text_encoder/config.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
files = metadata.download_urls(session=mm2_session)
# this will give a 404 error
files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
job = queue.multifile_download(
parts=files,
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
queue.join()
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
assert "HTTPError(NOT FOUND)" in job.error_type
assert DownloadJobStatus.ERROR in events
queue.stop()
@pytest.mark.timeout(timeout=15, method="thread")
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
cancelled = False
def cancelled_callback(job: DownloadJob) -> None:
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session),
dest=tmp_path,
on_cancelled=cancelled_callback,
)
queue.cancel_job(job)
queue.join()
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert "download_cancelled" in [x.event_name for x in events]
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
events = set()
def event_handler(job: DownloadJob) -> None:
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(