mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 04:18:46 +08:00
tests: update tests to use new events
This commit is contained in:
parent
655f62008f
commit
a876675448
@ -9,6 +9,11 @@ import pytest
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
|
||||
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecordNotFoundException,
|
||||
@ -281,9 +286,9 @@ def assert_handler_success(
|
||||
|
||||
# Check that the correct events were emitted
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_completed"
|
||||
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent)
|
||||
assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path)
|
||||
|
||||
|
||||
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
@ -329,9 +334,9 @@ def test_handler_on_generic_exception(
|
||||
event_bus: TestEventService = mock_invoker.services.events
|
||||
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == exception.__str__()
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
|
||||
assert event_bus.events[1].error == exception.__str__()
|
||||
|
||||
|
||||
def execute_handler_test_on_error(
|
||||
@ -344,9 +349,9 @@ def execute_handler_test_on_error(
|
||||
event_bus: TestEventService = mock_invoker.services.events
|
||||
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == error.__str__()
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
|
||||
assert event_bus.events[1].error == error.__str__()
|
||||
|
||||
|
||||
def test_delete(tmp_path: Path):
|
||||
|
@ -10,6 +10,13 @@ from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.events.events_common import (
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
)
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
@ -116,14 +123,22 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
queue.join()
|
||||
events = event_bus.events
|
||||
assert len(events) == 3
|
||||
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
|
||||
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
|
||||
assert events[0].event_name == "download_started"
|
||||
assert events[1].event_name == "download_progress"
|
||||
assert events[1].payload["total_bytes"] > 0
|
||||
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
|
||||
assert events[2].event_name == "download_complete"
|
||||
assert events[2].payload["total_bytes"] == 32029
|
||||
assert isinstance(events[0], DownloadStartedEvent)
|
||||
assert isinstance(events[1], DownloadProgressEvent)
|
||||
assert isinstance(events[2], DownloadCompleteEvent)
|
||||
assert events[0].timestamp <= events[1].timestamp
|
||||
assert events[1].timestamp <= events[2].timestamp
|
||||
assert events[1].total_bytes > 0
|
||||
assert events[1].current_bytes <= events[1].total_bytes
|
||||
assert events[2].total_bytes == 32029
|
||||
# assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
|
||||
# assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
|
||||
# assert events[0].event_name == "download_started"
|
||||
# assert events[1].event_name == "download_progress"
|
||||
# assert events[1].payload["total_bytes"] > 0
|
||||
# assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
|
||||
# assert events[2].event_name == "download_complete"
|
||||
# assert events[2].payload["total_bytes"] == 32029
|
||||
|
||||
# test a failure
|
||||
event_bus.events = [] # reset our accumulator
|
||||
@ -132,10 +147,15 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
events = event_bus.events
|
||||
print("\n".join([x.model_dump_json() for x in events]))
|
||||
assert len(events) == 1
|
||||
assert events[0].event_name == "download_error"
|
||||
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
|
||||
assert events[0].payload["error"] is not None
|
||||
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
|
||||
assert isinstance(events[0], DownloadErrorEvent)
|
||||
assert events[0].error_type == "HTTPError(NOT FOUND)"
|
||||
assert events[0].error is not None
|
||||
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error)
|
||||
|
||||
# assert events[0].event_name == "download_error"
|
||||
# assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
|
||||
# assert events[0].payload["error"] is not None
|
||||
# assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
|
||||
queue.stop()
|
||||
|
||||
|
||||
@ -202,6 +222,8 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
assert job.status == DownloadJobStatus.CANCELLED
|
||||
assert cancelled
|
||||
events = event_bus.events
|
||||
assert events[-1].event_name == "download_cancelled"
|
||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||
assert isinstance(events[-1], DownloadCancelledEvent)
|
||||
assert events[-1].source == "http://www.civitai.com/models/12345"
|
||||
# assert events[-1].event_name == "download_cancelled"
|
||||
# assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
||||
|
@ -13,6 +13,12 @@ from pydantic_core import Url
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.model_install import (
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
@ -25,6 +31,7 @@ from invokeai.app.services.model_install.model_install_common import (
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
OS = platform.uname().system
|
||||
|
||||
@ -132,19 +139,26 @@ def test_background_install(
|
||||
assert job.total_bytes == size
|
||||
|
||||
# test that the expected events were issued
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
assert bus
|
||||
assert hasattr(bus, "events")
|
||||
|
||||
assert len(bus.events) == 2
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert "model_install_running" in event_names
|
||||
assert "model_install_completed" in event_names
|
||||
assert Path(bus.events[0].payload["source"]) == source
|
||||
assert Path(bus.events[1].payload["source"]) == source
|
||||
key = bus.events[1].payload["key"]
|
||||
assert isinstance(bus.events[0], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallCompleteEvent)
|
||||
assert Path(bus.events[0].source) == source
|
||||
assert Path(bus.events[1].source) == source
|
||||
key = bus.events[1].key
|
||||
assert key is not None
|
||||
|
||||
# event_names = [x.event_name for x in bus.events]
|
||||
# assert "model_install_running" in event_names
|
||||
# assert "model_install_completed" in event_names
|
||||
# assert Path(bus.events[0].payload["source"]) == source
|
||||
# assert Path(bus.events[1].payload["source"]) == source
|
||||
# key = bus.events[1].payload["key"]
|
||||
# assert key is not None
|
||||
|
||||
# see if the thing actually got installed at the expected location
|
||||
model_record = mm2_installer.record_store.get_model(key)
|
||||
assert model_record is not None
|
||||
@ -221,7 +235,7 @@ def test_delete_register(
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert store is not None
|
||||
assert bus is not None
|
||||
@ -239,20 +253,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 4
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert event_names == [
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
]
|
||||
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
|
||||
assert isinstance(bus.events[2], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
@ -269,15 +280,10 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
|
||||
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
|
||||
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
}
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
|
@ -3,16 +3,13 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||
@ -39,27 +36,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
RepoHFModelJson1,
|
||||
)
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
|
||||
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
|
||||
@ -127,7 +104,7 @@ def mm2_installer(
|
||||
) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
events = TestEventService()
|
||||
store = ModelRecordServiceSQL(db)
|
||||
|
||||
installer = ModelInstallService(
|
||||
|
@ -1,7 +1,5 @@
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -10,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, OutputField
|
||||
from invokeai.app.invocations.image import ImageField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@ -117,11 +116,10 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
||||
)
|
||||
|
||||
|
||||
class TestEvent(BaseModel):
|
||||
class TestEvent(EventBase):
|
||||
__test__ = False # not a pytest test case
|
||||
|
||||
event_name: str
|
||||
payload: Any
|
||||
__event_name__ = "test_event"
|
||||
|
||||
|
||||
class TestEventService(EventServiceBase):
|
||||
@ -129,10 +127,10 @@ class TestEventService(EventServiceBase):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.events: list[TestEvent] = []
|
||||
self.events: list[EventBase] = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self.events.append(event)
|
||||
pass
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user