tests: update tests to use new events

This commit is contained in:
psychedelicious 2024-03-14 18:51:17 +11:00
parent 655f62008f
commit a876675448
5 changed files with 88 additions and 80 deletions

View File

@ -9,6 +9,11 @@ import pytest
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.events.events_common import (
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordNotFoundException,
@ -281,9 +286,9 @@ def assert_handler_success(
# Check that the correct events were emitted
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_completed"
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent)
assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path)
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
@ -329,9 +334,9 @@ def test_handler_on_generic_exception(
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_failed"
assert event_bus.events[1].payload["error"] == exception.__str__()
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
assert event_bus.events[1].error == exception.__str__()
def execute_handler_test_on_error(
@ -344,9 +349,9 @@ def execute_handler_test_on_error(
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_failed"
assert event_bus.events[1].payload["error"] == error.__str__()
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
assert event_bus.events[1].error == error.__str__()
def test_delete(tmp_path: Path):

View File

@ -10,6 +10,13 @@ from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.events.events_common import (
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
)
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
@ -116,14 +123,22 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.join()
events = event_bus.events
assert len(events) == 3
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
assert events[0].event_name == "download_started"
assert events[1].event_name == "download_progress"
assert events[1].payload["total_bytes"] > 0
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
assert events[2].event_name == "download_complete"
assert events[2].payload["total_bytes"] == 32029
assert isinstance(events[0], DownloadStartedEvent)
assert isinstance(events[1], DownloadProgressEvent)
assert isinstance(events[2], DownloadCompleteEvent)
assert events[0].timestamp <= events[1].timestamp
assert events[1].timestamp <= events[2].timestamp
assert events[1].total_bytes > 0
assert events[1].current_bytes <= events[1].total_bytes
assert events[2].total_bytes == 32029
# assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
# assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
# assert events[0].event_name == "download_started"
# assert events[1].event_name == "download_progress"
# assert events[1].payload["total_bytes"] > 0
# assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
# assert events[2].event_name == "download_complete"
# assert events[2].payload["total_bytes"] == 32029
# test a failure
event_bus.events = [] # reset our accumulator
@ -132,10 +147,15 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
events = event_bus.events
print("\n".join([x.model_dump_json() for x in events]))
assert len(events) == 1
assert events[0].event_name == "download_error"
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
assert events[0].payload["error"] is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
assert isinstance(events[0], DownloadErrorEvent)
assert events[0].error_type == "HTTPError(NOT FOUND)"
assert events[0].error is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error)
# assert events[0].event_name == "download_error"
# assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
# assert events[0].payload["error"] is not None
# assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
queue.stop()
@ -202,6 +222,8 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
assert isinstance(events[-1], DownloadCancelledEvent)
assert events[-1].source == "http://www.civitai.com/models/12345"
# assert events[-1].event_name == "download_cancelled"
# assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()

View File

@ -13,6 +13,12 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
)
from invokeai.app.services.model_install import (
ModelInstallServiceBase,
)
@ -25,6 +31,7 @@ from invokeai.app.services.model_install.model_install_common import (
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
OS = platform.uname().system
@ -132,19 +139,26 @@ def test_background_install(
assert job.total_bytes == size
# test that the expected events were issued
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
assert bus
assert hasattr(bus, "events")
assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events]
assert "model_install_running" in event_names
assert "model_install_completed" in event_names
assert Path(bus.events[0].payload["source"]) == source
assert Path(bus.events[1].payload["source"]) == source
key = bus.events[1].payload["key"]
assert isinstance(bus.events[0], ModelInstallStartedEvent)
assert isinstance(bus.events[1], ModelInstallCompleteEvent)
assert Path(bus.events[0].source) == source
assert Path(bus.events[1].source) == source
key = bus.events[1].key
assert key is not None
# event_names = [x.event_name for x in bus.events]
# assert "model_install_running" in event_names
# assert "model_install_completed" in event_names
# assert Path(bus.events[0].payload["source"]) == source
# assert Path(bus.events[1].payload["source"]) == source
# key = bus.events[1].payload["key"]
# assert key is not None
# see if the thing actually got installed at the expected location
model_record = mm2_installer.record_store.get_model(key)
assert model_record is not None
@ -221,7 +235,7 @@ def test_delete_register(
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
store = mm2_installer.record_store
assert store is not None
assert bus is not None
@ -239,20 +253,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
assert (mm2_app_config.models_path / model_record.path).exists()
assert len(bus.events) == 4
event_names = [x.event_name for x in bus.events]
assert event_names == [
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
]
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
assert isinstance(bus.events[2], ModelInstallStartedEvent)
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
@ -269,15 +280,10 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the dummyeventservice has this
assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:

View File

@ -3,16 +3,13 @@
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
@ -39,27 +36,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoHFModelJson1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
from tests.test_nodes import TestEventService
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
@ -127,7 +104,7 @@ def mm2_installer(
) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
events = TestEventService()
store = ModelRecordServiceSQL(db)
installer = ModelInstallService(

View File

@ -1,7 +1,5 @@
from typing import Any, Callable, Union
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -10,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -117,11 +116,10 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
)
class TestEvent(BaseModel):
class TestEvent(EventBase):
__test__ = False # not a pytest test case
event_name: str
payload: Any
__event_name__ = "test_event"
class TestEventService(EventServiceBase):
@ -129,10 +127,10 @@ class TestEventService(EventServiceBase):
def __init__(self):
super().__init__()
self.events: list[TestEvent] = []
self.events: list[EventBase] = []
def dispatch(self, event_name: str, payload: Any) -> None:
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
def dispatch(self, event: EventBase) -> None:
self.events.append(event)
pass