mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-04 22:43:40 +08:00
feat(nodes): add methods to invalidate invocation typeadapters
This commit is contained in:
parent
6155f9ff9e
commit
595133463e
@ -257,19 +257,28 @@ class InvocationRegistry:
|
||||
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls.get_invocation_typeadapter.cache_clear()
|
||||
cls.invalidate_invocation_typeadapter()
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation types.
|
||||
|
||||
This is used to parse serialized invocations into the correct invocation class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_invocation_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation type adapter."""
|
||||
cls.get_invocation_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
@ -306,7 +315,7 @@ class InvocationRegistry:
|
||||
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls.get_output_typeadapter.cache_clear()
|
||||
cls.invalidate_output_typeadapter()
|
||||
|
||||
@classmethod
|
||||
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
|
||||
@ -318,12 +327,21 @@ class InvocationRegistry:
|
||||
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
|
||||
|
||||
This is used to parse serialized invocation outputs into the correct invocation output class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_output_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation output type adapter."""
|
||||
cls.get_output_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
|
@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir):
|
||||
|
||||
# We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for
|
||||
# subsequent graph validations
|
||||
InvocationRegistry.get_invocation_typeadapter.cache_clear()
|
||||
InvocationRegistry.invalidate_invocation_typeadapter()
|
||||
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}})
|
||||
@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir):
|
||||
|
||||
# Reset the config so that it doesn't affect other tests
|
||||
get_config.cache_clear()
|
||||
InvocationRegistry.get_invocation_typeadapter.cache_clear()
|
||||
InvocationRegistry.invalidate_invocation_typeadapter()
|
||||
|
Loading…
x
Reference in New Issue
Block a user