diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 31ac02ce8e..b2f3c3a9f2 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -8,6 +8,7 @@ import sys import warnings from abc import ABC, abstractmethod from enum import Enum +from functools import lru_cache from inspect import signature from typing import ( TYPE_CHECKING, @@ -27,7 +28,6 @@ import semver from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined -from typing_extensions import TypeAliasType from invokeai.app.invocations.fields import ( FieldKind, @@ -100,37 +100,6 @@ class BaseInvocationOutput(BaseModel): All invocation outputs must use the `@invocation_output` decorator to provide their unique type. """ - _output_classes: ClassVar[set[BaseInvocationOutput]] = set() - _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None - _typeadapter_needs_update: ClassVar[bool] = False - - @classmethod - def register_output(cls, output: BaseInvocationOutput) -> None: - """Registers an invocation output.""" - cls._output_classes.add(output) - cls._typeadapter_needs_update = True - - @classmethod - def get_outputs(cls) -> Iterable[BaseInvocationOutput]: - """Gets all invocation outputs.""" - return cls._output_classes - - @classmethod - def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantc TypeAdapter for the union of all invocation output types.""" - if not cls._typeadapter or cls._typeadapter_needs_update: - AnyInvocationOutput = TypeAliasType( - "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] - ) - cls._typeadapter = TypeAdapter(AnyInvocationOutput) - cls._typeadapter_needs_update = False - return cls._typeadapter - - @classmethod - def get_output_types(cls) -> Iterable[str]: - """Gets all invocation output types.""" - return (i.get_type() for i in BaseInvocationOutput.get_outputs()) - @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None: """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" @@ -173,76 +142,16 @@ class BaseInvocation(ABC, BaseModel): All invocations must use the `@invocation` decorator to provide their unique type. """ - _invocation_classes: ClassVar[set[BaseInvocation]] = set() - _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None - _typeadapter_needs_update: ClassVar[bool] = False - @classmethod def get_type(cls) -> str: """Gets the invocation's type, as provided by the `@invocation` decorator.""" return cls.model_fields["type"].default - @classmethod - def register_invocation(cls, invocation: BaseInvocation) -> None: - """Registers an invocation.""" - cls._invocation_classes.add(invocation) - cls._typeadapter_needs_update = True - - @classmethod - def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantc TypeAdapter for the union of all invocation types.""" - if not cls._typeadapter or cls._typeadapter_needs_update: - AnyInvocation = TypeAliasType( - "AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")] - ) - cls._typeadapter = TypeAdapter(AnyInvocation) - cls._typeadapter_needs_update = False - return cls._typeadapter - - @classmethod - def invalidate_typeadapter(cls) -> None: - """Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or - denylist is changed, this should be called to ensure the typeadapter is updated and validation respects - the updated allowlist and denylist.""" - cls._typeadapter_needs_update = True - - @classmethod - def get_invocations(cls) -> Iterable[BaseInvocation]: - """Gets all invocations, respecting the allowlist and denylist.""" - app_config = get_config() - allowed_invocations: set[BaseInvocation] = set() - for sc in cls._invocation_classes: - invocation_type = sc.get_type() - is_in_allowlist = ( - invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True - ) - is_in_denylist = ( - invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False - ) - if is_in_allowlist and not is_in_denylist: - allowed_invocations.add(sc) - return allowed_invocations - - @classmethod - def get_invocations_map(cls) -> dict[str, BaseInvocation]: - """Gets a map of all invocation types to their invocation classes.""" - return {i.get_type(): i for i in BaseInvocation.get_invocations()} - - @classmethod - def get_invocation_types(cls) -> Iterable[str]: - """Gets all invocation types.""" - return (i.get_type() for i in BaseInvocation.get_invocations()) - @classmethod def get_output_annotation(cls) -> BaseInvocationOutput: """Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method).""" return signature(cls.invoke).return_annotation - @classmethod - def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None: - """Gets the invocation class for a given invocation type.""" - return cls.get_invocations_map().get(invocation_type) - @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None: """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" @@ -340,6 +249,105 @@ class BaseInvocation(ABC, BaseModel): TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation) +class InvocationRegistry: + _invocation_classes: ClassVar[set[type[BaseInvocation]]] = set() + _output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set() + + @classmethod + def register_invocation(cls, invocation: type[BaseInvocation]) -> None: + """Registers an invocation.""" + cls._invocation_classes.add(invocation) + cls.invalidate_invocation_typeadapter() + + @classmethod + @lru_cache(maxsize=1) + def get_invocation_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantic TypeAdapter for the union of all invocation types. + + This is used to parse serialized invocations into the correct invocation class. + + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + + @see https://docs.pydantic.dev/latest/concepts/type_adapter/ + """ + return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")]) + + @classmethod + def invalidate_invocation_typeadapter(cls) -> None: + """Invalidates the cached invocation type adapter.""" + cls.get_invocation_typeadapter.cache_clear() + + @classmethod + def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]: + """Gets all invocations, respecting the allowlist and denylist.""" + app_config = get_config() + allowed_invocations: set[type[BaseInvocation]] = set() + for sc in cls._invocation_classes: + invocation_type = sc.get_type() + is_in_allowlist = ( + invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True + ) + is_in_denylist = ( + invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False + ) + if is_in_allowlist and not is_in_denylist: + allowed_invocations.add(sc) + return allowed_invocations + + @classmethod + def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]: + """Gets a map of all invocation types to their invocation classes.""" + return {i.get_type(): i for i in cls.get_invocation_classes()} + + @classmethod + def get_invocation_types(cls) -> Iterable[str]: + """Gets all invocation types.""" + return (i.get_type() for i in cls.get_invocation_classes()) + + @classmethod + def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None: + """Gets the invocation class for a given invocation type.""" + return cls.get_invocations_map().get(invocation_type) + + @classmethod + def register_output(cls, output: "type[TBaseInvocationOutput]") -> None: + """Registers an invocation output.""" + cls._output_classes.add(output) + cls.invalidate_output_typeadapter() + + @classmethod + def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]: + """Gets all invocation outputs.""" + return cls._output_classes + + @classmethod + @lru_cache(maxsize=1) + def get_output_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantic TypeAdapter for the union of all invocation output types. + + This is used to parse serialized invocation outputs into the correct invocation output class. + + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + + @see https://docs.pydantic.dev/latest/concepts/type_adapter/ + """ + return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) + + @classmethod + def invalidate_output_typeadapter(cls) -> None: + """Invalidates the cached invocation output type adapter.""" + cls.get_output_typeadapter.cache_clear() + + @classmethod + def get_output_types(cls) -> Iterable[str]: + """Gets all invocation output types.""" + return (i.get_type() for i in cls.get_output_classes()) + + RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = { "id", "is_intermediate", @@ -453,8 +461,8 @@ def invocation( node_pack = cls.__module__.split(".")[0] # Handle the case where an existing node is being clobbered by the one we are registering - if invocation_type in BaseInvocation.get_invocation_types(): - clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type) + if invocation_type in InvocationRegistry.get_invocation_types(): + clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type) # This should always be true - we just checked if the invocation type was in the set assert clobbered_invocation is not None @@ -539,8 +547,7 @@ def invocation( ) cls.__doc__ = docstring - # TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic - BaseInvocation.register_invocation(cls) # type: ignore + InvocationRegistry.register_invocation(cls) return cls @@ -565,7 +572,7 @@ def invocation_output( if re.compile(r"^\S+$").match(output_type) is None: raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"') - if output_type in BaseInvocationOutput.get_output_types(): + if output_type in InvocationRegistry.get_output_types(): raise ValueError(f'Invocation type "{output_type}" already exists') validate_fields(cls.model_fields, output_type) @@ -586,7 +593,7 @@ def invocation_output( ) cls.__doc__ = docstring - BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly? + InvocationRegistry.register_output(cls) return cls diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 2d425a7515..fef99753ee 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403 from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + InvocationRegistry, invocation, invocation_output, ) @@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation): @classmethod def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: def validate_invocation(v: Any) -> "AnyInvocation": - return BaseInvocation.get_typeadapter().validate_python(v) + return InvocationRegistry.get_invocation_typeadapter().validate_python(v) return core_schema.no_info_plain_validator_function(validate_invocation) @@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation): # Nodes are too powerful, we have to make our own OpenAPI schema manually # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually oneOf: list[dict[str, str]] = [] - names = [i.__name__ for i in BaseInvocation.get_invocations()] + names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()] for name in sorted(names): oneOf.append({"$ref": f"#/components/schemas/{name}"}) return {"oneOf": oneOf} @@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput): @classmethod def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): def validate_invocation_output(v: Any) -> "AnyInvocationOutput": - return BaseInvocationOutput.get_typeadapter().validate_python(v) + return InvocationRegistry.get_output_typeadapter().validate_python(v) return core_schema.no_info_plain_validator_function(validate_invocation_output) @@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput): # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually oneOf: list[dict[str, str]] = [] - names = [i.__name__ for i in BaseInvocationOutput.get_outputs()] + names = [i.__name__ for i in InvocationRegistry.get_output_classes()] for name in sorted(names): oneOf.append({"$ref": f"#/components/schemas/{name}"}) return {"oneOf": oneOf} diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py index e52028d772..85e17e21c6 100644 --- a/invokeai/app/util/custom_openapi.py +++ b/invokeai/app/util/custom_openapi.py @@ -4,7 +4,10 @@ from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from pydantic.json_schema import models_json_schema -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase +from invokeai.app.invocations.baseinvocation import ( + InvocationRegistry, + UIConfigBase, +) from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.events.events_common import EventBase @@ -56,14 +59,14 @@ def get_openapi_func( invocation_output_map_required: list[str] = [] # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly. - for output in BaseInvocationOutput.get_outputs(): + for output in InvocationRegistry.get_output_classes(): json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") move_defs_to_top_level(openapi_schema, json_schema) openapi_schema["components"]["schemas"][output.__name__] = json_schema # Technically, invocations are added to the schema by pydantic, but we still need to manually set their output # property, so we'll just do it all manually. - for invocation in BaseInvocation.get_invocations(): + for invocation in InvocationRegistry.get_invocation_classes(): json_schema = invocation.model_json_schema( mode="serialization", ref_template="#/components/schemas/{model}" ) diff --git a/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts b/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts index bc78794d77..f5cc5a7f3f 100644 --- a/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts +++ b/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts @@ -6,6 +6,16 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; import { useCallback } from 'react'; import { useCreateImageUploadEntryMutation } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; + +type PresignedUrlResponse = { + fullUrl: string; + thumbnailUrl: string; +}; + +const isPresignedUrlResponse = (response: unknown): response is PresignedUrlResponse => { + return typeof response === 'object' && response !== null && 'fullUrl' in response && 'thumbnailUrl' in response; +}; + export const useClientSideUpload = () => { const dispatch = useAppDispatch(); const autoAddBoardId = useAppSelector(selectAutoAddBoardId); @@ -74,24 +84,30 @@ export const useClientSideUpload = () => { board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, }).unwrap(); - await fetch(`${presigned_url}/?type=full`, { + const response = await fetch(presigned_url, { + method: 'GET', + ...(authToken && { + headers: { + Authorization: `Bearer ${authToken}`, + }, + }), + }).then((res) => res.json()); + + if (!isPresignedUrlResponse(response)) { + throw new Error('Invalid response'); + } + + const fullUrl = response.fullUrl; + const thumbnailUrl = response.thumbnailUrl; + + await fetch(fullUrl, { method: 'PUT', body: file, - ...(authToken && { - headers: { - Authorization: `Bearer ${authToken}`, - }, - }), }); - await fetch(`${presigned_url}/?type=thumbnail`, { + await fetch(thumbnailUrl, { method: 'PUT', body: thumbnail, - ...(authToken && { - headers: { - Authorization: `Bearer ${authToken}`, - }, - }), }); dispatch(imageUploadedClientSide({ imageDTO: image_dto, silent: false, isFirstUploadOfBatch: i === 0 })); diff --git a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx index 9fedc501f5..aead53ae75 100644 --- a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx +++ b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx @@ -58,50 +58,58 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us const onDropAccepted = useCallback( async (files: File[]) => { - if (!allowMultiple) { - if (files.length > 1) { - log.warn('Multiple files dropped but only one allowed'); - return; - } - if (files.length === 0) { - // Should never happen - log.warn('No files dropped'); - return; - } - const file = files[0]; - assert(file !== undefined); // should never happen - const imageDTO = await uploadImage({ - file, - image_category: 'user', - is_intermediate: false, - board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, - silent: true, - }).unwrap(); - if (onUpload) { - onUpload(imageDTO); - } - } else { - let imageDTOs: ImageDTO[] = []; - if (isClientSideUploadEnabled) { - imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i))); + try { + if (!allowMultiple) { + if (files.length > 1) { + log.warn('Multiple files dropped but only one allowed'); + return; + } + if (files.length === 0) { + // Should never happen + log.warn('No files dropped'); + return; + } + const file = files[0]; + assert(file !== undefined); // should never happen + const imageDTO = await uploadImage({ + file, + image_category: 'user', + is_intermediate: false, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + silent: true, + }).unwrap(); + if (onUpload) { + onUpload(imageDTO); + } } else { - imageDTOs = await uploadImages( - files.map((file, i) => ({ - file, - image_category: 'user', - is_intermediate: false, - board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, - silent: false, - isFirstUploadOfBatch: i === 0, - })) - ); - } - if (onUpload) { - onUpload(imageDTOs); + let imageDTOs: ImageDTO[] = []; + if (isClientSideUploadEnabled) { + imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i))); + } else { + imageDTOs = await uploadImages( + files.map((file, i) => ({ + file, + image_category: 'user', + is_intermediate: false, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + silent: false, + isFirstUploadOfBatch: i === 0, + })) + ); + } + if (onUpload) { + onUpload(imageDTOs); + } } + } catch (error) { + toast({ + id: 'UPLOAD_FAILED', + title: t('toast.imageUploadFailed'), + status: 'error', + }); } }, - [allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload] + [allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t] ); const onDropRejected = useCallback( diff --git a/tests/test_config.py b/tests/test_config.py index 220d6f257a..610162c075 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ from typing import Any import pytest from pydantic import ValidationError -from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.invocations.baseinvocation import InvocationRegistry from invokeai.app.services.config.config_default import ( DefaultInvokeAIAppConfig, InvokeAIAppConfig, @@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir): # We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for # subsequent graph validations - BaseInvocation.invalidate_typeadapter() + InvocationRegistry.invalidate_invocation_typeadapter() # confirm graph validation fails when using denied node Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}}) @@ -284,7 +284,7 @@ def test_deny_nodes(patch_rootdir): Graph.model_validate({"nodes": {"1": {"id": "1", "type": "float"}}}) # confirm invocations union will not have denied nodes - all_invocations = BaseInvocation.get_invocations() + all_invocations = InvocationRegistry.get_invocation_classes() has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1 has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1 @@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir): # Reset the config so that it doesn't affect other tests get_config.cache_clear() - BaseInvocation.invalidate_typeadapter() + InvocationRegistry.invalidate_invocation_typeadapter()