From 7be87c8048d6c18431bbedbb2ee9cd2073e7b984 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 8 Mar 2025 18:38:59 +1000 Subject: [PATCH 1/5] refactor(nodes): simpler logic for baseinvocation typeadapter handling --- invokeai/app/invocations/baseinvocation.py | 47 +++++++++------------- tests/test_config.py | 4 +- 2 files changed, 20 insertions(+), 31 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 31ac02ce8e..148e4993be 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -8,6 +8,7 @@ import sys import warnings from abc import ABC, abstractmethod from enum import Enum +from functools import lru_cache from inspect import signature from typing import ( TYPE_CHECKING, @@ -27,7 +28,6 @@ import semver from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined -from typing_extensions import TypeAliasType from invokeai.app.invocations.fields import ( FieldKind, @@ -101,14 +101,12 @@ class BaseInvocationOutput(BaseModel): """ _output_classes: ClassVar[set[BaseInvocationOutput]] = set() - _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None - _typeadapter_needs_update: ClassVar[bool] = False @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: """Registers an invocation output.""" cls._output_classes.add(output) - cls._typeadapter_needs_update = True + cls.get_typeadapter.cache_clear() @classmethod def get_outputs(cls) -> Iterable[BaseInvocationOutput]: @@ -116,15 +114,15 @@ class BaseInvocationOutput(BaseModel): return cls._output_classes @classmethod + @lru_cache(maxsize=1) def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantc TypeAdapter for the union of all invocation output types.""" - if not cls._typeadapter or cls._typeadapter_needs_update: - AnyInvocationOutput = TypeAliasType( - "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] - ) - cls._typeadapter = TypeAdapter(AnyInvocationOutput) - cls._typeadapter_needs_update = False - return cls._typeadapter + """Gets a pydantic TypeAdapter for the union of all invocation output types. + + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + """ + return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) @classmethod def get_output_types(cls) -> Iterable[str]: @@ -174,8 +172,6 @@ class BaseInvocation(ABC, BaseModel): """ _invocation_classes: ClassVar[set[BaseInvocation]] = set() - _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None - _typeadapter_needs_update: ClassVar[bool] = False @classmethod def get_type(cls) -> str: @@ -186,25 +182,18 @@ class BaseInvocation(ABC, BaseModel): def register_invocation(cls, invocation: BaseInvocation) -> None: """Registers an invocation.""" cls._invocation_classes.add(invocation) - cls._typeadapter_needs_update = True + cls.get_typeadapter.cache_clear() @classmethod + @lru_cache(maxsize=1) def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantc TypeAdapter for the union of all invocation types.""" - if not cls._typeadapter or cls._typeadapter_needs_update: - AnyInvocation = TypeAliasType( - "AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")] - ) - cls._typeadapter = TypeAdapter(AnyInvocation) - cls._typeadapter_needs_update = False - return cls._typeadapter + """Gets a pydantic TypeAdapter for the union of all invocation types. - @classmethod - def invalidate_typeadapter(cls) -> None: - """Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or - denylist is changed, this should be called to ensure the typeadapter is updated and validation respects - the updated allowlist and denylist.""" - cls._typeadapter_needs_update = True + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + """ + return TypeAdapter(Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]) @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]: diff --git a/tests/test_config.py b/tests/test_config.py index 220d6f257a..a4e7b0038c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir): # We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for # subsequent graph validations - BaseInvocation.invalidate_typeadapter() + BaseInvocation.get_typeadapter.cache_clear() # confirm graph validation fails when using denied node Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}}) @@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir): # Reset the config so that it doesn't affect other tests get_config.cache_clear() - BaseInvocation.invalidate_typeadapter() + BaseInvocation.get_typeadapter.cache_clear() From 6155f9ff9ec91c7676365a1879cdeebbbba91d4e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 11 Mar 2025 10:02:22 +1000 Subject: [PATCH 2/5] feat(nodes): move invocation/output registration to separate class --- invokeai/app/invocations/baseinvocation.py | 172 ++++++++++----------- invokeai/app/services/shared/graph.py | 9 +- invokeai/app/util/custom_openapi.py | 9 +- tests/test_config.py | 8 +- 4 files changed, 101 insertions(+), 97 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 148e4993be..aa1dbe3af4 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -100,35 +100,6 @@ class BaseInvocationOutput(BaseModel): All invocation outputs must use the `@invocation_output` decorator to provide their unique type. """ - _output_classes: ClassVar[set[BaseInvocationOutput]] = set() - - @classmethod - def register_output(cls, output: BaseInvocationOutput) -> None: - """Registers an invocation output.""" - cls._output_classes.add(output) - cls.get_typeadapter.cache_clear() - - @classmethod - def get_outputs(cls) -> Iterable[BaseInvocationOutput]: - """Gets all invocation outputs.""" - return cls._output_classes - - @classmethod - @lru_cache(maxsize=1) - def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantic TypeAdapter for the union of all invocation output types. - - This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or - denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects - the updated allowlist and denylist. - """ - return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) - - @classmethod - def get_output_types(cls) -> Iterable[str]: - """Gets all invocation output types.""" - return (i.get_type() for i in BaseInvocationOutput.get_outputs()) - @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None: """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" @@ -171,67 +142,16 @@ class BaseInvocation(ABC, BaseModel): All invocations must use the `@invocation` decorator to provide their unique type. """ - _invocation_classes: ClassVar[set[BaseInvocation]] = set() - @classmethod def get_type(cls) -> str: """Gets the invocation's type, as provided by the `@invocation` decorator.""" return cls.model_fields["type"].default - @classmethod - def register_invocation(cls, invocation: BaseInvocation) -> None: - """Registers an invocation.""" - cls._invocation_classes.add(invocation) - cls.get_typeadapter.cache_clear() - - @classmethod - @lru_cache(maxsize=1) - def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantic TypeAdapter for the union of all invocation types. - - This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or - denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects - the updated allowlist and denylist. - """ - return TypeAdapter(Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]) - - @classmethod - def get_invocations(cls) -> Iterable[BaseInvocation]: - """Gets all invocations, respecting the allowlist and denylist.""" - app_config = get_config() - allowed_invocations: set[BaseInvocation] = set() - for sc in cls._invocation_classes: - invocation_type = sc.get_type() - is_in_allowlist = ( - invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True - ) - is_in_denylist = ( - invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False - ) - if is_in_allowlist and not is_in_denylist: - allowed_invocations.add(sc) - return allowed_invocations - - @classmethod - def get_invocations_map(cls) -> dict[str, BaseInvocation]: - """Gets a map of all invocation types to their invocation classes.""" - return {i.get_type(): i for i in BaseInvocation.get_invocations()} - - @classmethod - def get_invocation_types(cls) -> Iterable[str]: - """Gets all invocation types.""" - return (i.get_type() for i in BaseInvocation.get_invocations()) - @classmethod def get_output_annotation(cls) -> BaseInvocationOutput: """Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method).""" return signature(cls.invoke).return_annotation - @classmethod - def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None: - """Gets the invocation class for a given invocation type.""" - return cls.get_invocations_map().get(invocation_type) - @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None: """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" @@ -329,6 +249,87 @@ class BaseInvocation(ABC, BaseModel): TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation) +class InvocationRegistry: + _invocation_classes: ClassVar[set[type[BaseInvocation]]] = set() + _output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set() + + @classmethod + def register_invocation(cls, invocation: type[BaseInvocation]) -> None: + """Registers an invocation.""" + cls._invocation_classes.add(invocation) + cls.get_invocation_typeadapter.cache_clear() + + @classmethod + @lru_cache(maxsize=1) + def get_invocation_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantic TypeAdapter for the union of all invocation types. + + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + """ + return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")]) + + @classmethod + def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]: + """Gets all invocations, respecting the allowlist and denylist.""" + app_config = get_config() + allowed_invocations: set[type[BaseInvocation]] = set() + for sc in cls._invocation_classes: + invocation_type = sc.get_type() + is_in_allowlist = ( + invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True + ) + is_in_denylist = ( + invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False + ) + if is_in_allowlist and not is_in_denylist: + allowed_invocations.add(sc) + return allowed_invocations + + @classmethod + def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]: + """Gets a map of all invocation types to their invocation classes.""" + return {i.get_type(): i for i in cls.get_invocation_classes()} + + @classmethod + def get_invocation_types(cls) -> Iterable[str]: + """Gets all invocation types.""" + return (i.get_type() for i in cls.get_invocation_classes()) + + @classmethod + def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None: + """Gets the invocation class for a given invocation type.""" + return cls.get_invocations_map().get(invocation_type) + + @classmethod + def register_output(cls, output: "type[TBaseInvocationOutput]") -> None: + """Registers an invocation output.""" + cls._output_classes.add(output) + cls.get_output_typeadapter.cache_clear() + + @classmethod + def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]: + """Gets all invocation outputs.""" + return cls._output_classes + + @classmethod + @lru_cache(maxsize=1) + def get_output_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantic TypeAdapter for the union of all invocation output types. + + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + """ + return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) + + @classmethod + def get_output_types(cls) -> Iterable[str]: + """Gets all invocation output types.""" + return (i.get_type() for i in cls.get_output_classes()) + + RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = { "id", "is_intermediate", @@ -442,8 +443,8 @@ def invocation( node_pack = cls.__module__.split(".")[0] # Handle the case where an existing node is being clobbered by the one we are registering - if invocation_type in BaseInvocation.get_invocation_types(): - clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type) + if invocation_type in InvocationRegistry.get_invocation_types(): + clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type) # This should always be true - we just checked if the invocation type was in the set assert clobbered_invocation is not None @@ -528,8 +529,7 @@ def invocation( ) cls.__doc__ = docstring - # TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic - BaseInvocation.register_invocation(cls) # type: ignore + InvocationRegistry.register_invocation(cls) return cls @@ -554,7 +554,7 @@ def invocation_output( if re.compile(r"^\S+$").match(output_type) is None: raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"') - if output_type in BaseInvocationOutput.get_output_types(): + if output_type in InvocationRegistry.get_output_types(): raise ValueError(f'Invocation type "{output_type}" already exists') validate_fields(cls.model_fields, output_type) @@ -575,7 +575,7 @@ def invocation_output( ) cls.__doc__ = docstring - BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly? + InvocationRegistry.register_output(cls) return cls diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 2d425a7515..fef99753ee 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403 from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + InvocationRegistry, invocation, invocation_output, ) @@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation): @classmethod def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: def validate_invocation(v: Any) -> "AnyInvocation": - return BaseInvocation.get_typeadapter().validate_python(v) + return InvocationRegistry.get_invocation_typeadapter().validate_python(v) return core_schema.no_info_plain_validator_function(validate_invocation) @@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation): # Nodes are too powerful, we have to make our own OpenAPI schema manually # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually oneOf: list[dict[str, str]] = [] - names = [i.__name__ for i in BaseInvocation.get_invocations()] + names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()] for name in sorted(names): oneOf.append({"$ref": f"#/components/schemas/{name}"}) return {"oneOf": oneOf} @@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput): @classmethod def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): def validate_invocation_output(v: Any) -> "AnyInvocationOutput": - return BaseInvocationOutput.get_typeadapter().validate_python(v) + return InvocationRegistry.get_output_typeadapter().validate_python(v) return core_schema.no_info_plain_validator_function(validate_invocation_output) @@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput): # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually oneOf: list[dict[str, str]] = [] - names = [i.__name__ for i in BaseInvocationOutput.get_outputs()] + names = [i.__name__ for i in InvocationRegistry.get_output_classes()] for name in sorted(names): oneOf.append({"$ref": f"#/components/schemas/{name}"}) return {"oneOf": oneOf} diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py index e52028d772..85e17e21c6 100644 --- a/invokeai/app/util/custom_openapi.py +++ b/invokeai/app/util/custom_openapi.py @@ -4,7 +4,10 @@ from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from pydantic.json_schema import models_json_schema -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase +from invokeai.app.invocations.baseinvocation import ( + InvocationRegistry, + UIConfigBase, +) from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.events.events_common import EventBase @@ -56,14 +59,14 @@ def get_openapi_func( invocation_output_map_required: list[str] = [] # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly. - for output in BaseInvocationOutput.get_outputs(): + for output in InvocationRegistry.get_output_classes(): json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") move_defs_to_top_level(openapi_schema, json_schema) openapi_schema["components"]["schemas"][output.__name__] = json_schema # Technically, invocations are added to the schema by pydantic, but we still need to manually set their output # property, so we'll just do it all manually. - for invocation in BaseInvocation.get_invocations(): + for invocation in InvocationRegistry.get_invocation_classes(): json_schema = invocation.model_json_schema( mode="serialization", ref_template="#/components/schemas/{model}" ) diff --git a/tests/test_config.py b/tests/test_config.py index a4e7b0038c..087dc89db5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ from typing import Any import pytest from pydantic import ValidationError -from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.invocations.baseinvocation import InvocationRegistry from invokeai.app.services.config.config_default import ( DefaultInvokeAIAppConfig, InvokeAIAppConfig, @@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir): # We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for # subsequent graph validations - BaseInvocation.get_typeadapter.cache_clear() + InvocationRegistry.get_invocation_typeadapter.cache_clear() # confirm graph validation fails when using denied node Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}}) @@ -284,7 +284,7 @@ def test_deny_nodes(patch_rootdir): Graph.model_validate({"nodes": {"1": {"id": "1", "type": "float"}}}) # confirm invocations union will not have denied nodes - all_invocations = BaseInvocation.get_invocations() + all_invocations = InvocationRegistry.get_invocation_classes() has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1 has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1 @@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir): # Reset the config so that it doesn't affect other tests get_config.cache_clear() - BaseInvocation.get_typeadapter.cache_clear() + InvocationRegistry.get_invocation_typeadapter.cache_clear() From 595133463e50cc3fd759aaf63cb9916c497c4a63 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 24 Mar 2025 09:40:33 +1000 Subject: [PATCH 3/5] feat(nodes): add methods to invalidate invocation typeadapters --- invokeai/app/invocations/baseinvocation.py | 22 ++++++++++++++++++++-- tests/test_config.py | 4 ++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index aa1dbe3af4..b2f3c3a9f2 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -257,19 +257,28 @@ class InvocationRegistry: def register_invocation(cls, invocation: type[BaseInvocation]) -> None: """Registers an invocation.""" cls._invocation_classes.add(invocation) - cls.get_invocation_typeadapter.cache_clear() + cls.invalidate_invocation_typeadapter() @classmethod @lru_cache(maxsize=1) def get_invocation_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantic TypeAdapter for the union of all invocation types. + This is used to parse serialized invocations into the correct invocation class. + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects the updated allowlist and denylist. + + @see https://docs.pydantic.dev/latest/concepts/type_adapter/ """ return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")]) + @classmethod + def invalidate_invocation_typeadapter(cls) -> None: + """Invalidates the cached invocation type adapter.""" + cls.get_invocation_typeadapter.cache_clear() + @classmethod def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]: """Gets all invocations, respecting the allowlist and denylist.""" @@ -306,7 +315,7 @@ class InvocationRegistry: def register_output(cls, output: "type[TBaseInvocationOutput]") -> None: """Registers an invocation output.""" cls._output_classes.add(output) - cls.get_output_typeadapter.cache_clear() + cls.invalidate_output_typeadapter() @classmethod def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]: @@ -318,12 +327,21 @@ class InvocationRegistry: def get_output_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantic TypeAdapter for the union of all invocation output types. + This is used to parse serialized invocation outputs into the correct invocation output class. + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects the updated allowlist and denylist. + + @see https://docs.pydantic.dev/latest/concepts/type_adapter/ """ return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) + @classmethod + def invalidate_output_typeadapter(cls) -> None: + """Invalidates the cached invocation output type adapter.""" + cls.get_output_typeadapter.cache_clear() + @classmethod def get_output_types(cls) -> Iterable[str]: """Gets all invocation output types.""" diff --git a/tests/test_config.py b/tests/test_config.py index 087dc89db5..610162c075 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir): # We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for # subsequent graph validations - InvocationRegistry.get_invocation_typeadapter.cache_clear() + InvocationRegistry.invalidate_invocation_typeadapter() # confirm graph validation fails when using denied node Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}}) @@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir): # Reset the config so that it doesn't affect other tests get_config.cache_clear() - InvocationRegistry.get_invocation_typeadapter.cache_clear() + InvocationRegistry.invalidate_invocation_typeadapter() From 77bf5c15bb35f05c05d310c6f0b0f71618161088 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Mon, 31 Mar 2025 11:40:37 -0400 Subject: [PATCH 4/5] GET presigned URLs directly instead of trying to use redirects --- .../src/common/hooks/useClientSideUpload.ts | 40 ++++++--- .../src/common/hooks/useImageUploadButton.tsx | 88 ++++++++++--------- 2 files changed, 76 insertions(+), 52 deletions(-) diff --git a/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts b/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts index bc78794d77..a45a35dacb 100644 --- a/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts +++ b/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts @@ -6,6 +6,16 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; import { useCallback } from 'react'; import { useCreateImageUploadEntryMutation } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; + +type PresignedUrlResponse = { + fullUrl: string; + thumbnailUrl: string; +}; + +const isPresignedUrlResponse = (response: unknown): response is PresignedUrlResponse => { + return typeof response === 'object' && response !== null && 'fullUrl' in response && 'thumbnailUrl' in response; +}; + export const useClientSideUpload = () => { const dispatch = useAppDispatch(); const autoAddBoardId = useAppSelector(selectAutoAddBoardId); @@ -74,24 +84,30 @@ export const useClientSideUpload = () => { board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, }).unwrap(); - await fetch(`${presigned_url}/?type=full`, { + const response = await fetch(presigned_url, { + method: 'GET', + ...(authToken && { + headers: { + Authorization: `Bearer ${authToken}`, + }, + }), + }).then((res) => res.json()); + + if (isPresignedUrlResponse(response)) { + throw new Error('Invalid response'); + } + + const fullUrl = response.fullUrl; + const thumbnailUrl = response.thumbnailUrl; + + await fetch(fullUrl, { method: 'PUT', body: file, - ...(authToken && { - headers: { - Authorization: `Bearer ${authToken}`, - }, - }), }); - await fetch(`${presigned_url}/?type=thumbnail`, { + await fetch(thumbnailUrl, { method: 'PUT', body: thumbnail, - ...(authToken && { - headers: { - Authorization: `Bearer ${authToken}`, - }, - }), }); dispatch(imageUploadedClientSide({ imageDTO: image_dto, silent: false, isFirstUploadOfBatch: i === 0 })); diff --git a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx index 9fedc501f5..aead53ae75 100644 --- a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx +++ b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx @@ -58,50 +58,58 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us const onDropAccepted = useCallback( async (files: File[]) => { - if (!allowMultiple) { - if (files.length > 1) { - log.warn('Multiple files dropped but only one allowed'); - return; - } - if (files.length === 0) { - // Should never happen - log.warn('No files dropped'); - return; - } - const file = files[0]; - assert(file !== undefined); // should never happen - const imageDTO = await uploadImage({ - file, - image_category: 'user', - is_intermediate: false, - board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, - silent: true, - }).unwrap(); - if (onUpload) { - onUpload(imageDTO); - } - } else { - let imageDTOs: ImageDTO[] = []; - if (isClientSideUploadEnabled) { - imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i))); + try { + if (!allowMultiple) { + if (files.length > 1) { + log.warn('Multiple files dropped but only one allowed'); + return; + } + if (files.length === 0) { + // Should never happen + log.warn('No files dropped'); + return; + } + const file = files[0]; + assert(file !== undefined); // should never happen + const imageDTO = await uploadImage({ + file, + image_category: 'user', + is_intermediate: false, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + silent: true, + }).unwrap(); + if (onUpload) { + onUpload(imageDTO); + } } else { - imageDTOs = await uploadImages( - files.map((file, i) => ({ - file, - image_category: 'user', - is_intermediate: false, - board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, - silent: false, - isFirstUploadOfBatch: i === 0, - })) - ); - } - if (onUpload) { - onUpload(imageDTOs); + let imageDTOs: ImageDTO[] = []; + if (isClientSideUploadEnabled) { + imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i))); + } else { + imageDTOs = await uploadImages( + files.map((file, i) => ({ + file, + image_category: 'user', + is_intermediate: false, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + silent: false, + isFirstUploadOfBatch: i === 0, + })) + ); + } + if (onUpload) { + onUpload(imageDTOs); + } } + } catch (error) { + toast({ + id: 'UPLOAD_FAILED', + title: t('toast.imageUploadFailed'), + status: 'error', + }); } }, - [allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload] + [allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t] ); const onDropRejected = useCallback( From a5851ca31cc60714adea0e92c016771e23cdde35 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Mon, 31 Mar 2025 11:41:29 -0400 Subject: [PATCH 5/5] fix from leftover testing --- invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts b/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts index a45a35dacb..f5cc5a7f3f 100644 --- a/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts +++ b/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts @@ -93,7 +93,7 @@ export const useClientSideUpload = () => { }), }).then((res) => res.json()); - if (isPresignedUrlResponse(response)) { + if (!isPresignedUrlResponse(response)) { throw new Error('Invalid response'); }