feat(nodes): add methods to invalidate invocation typeadapters

This commit is contained in:
psychedelicious 2025-03-24 09:40:33 +10:00
parent 6155f9ff9e
commit 595133463e
2 changed files with 22 additions and 4 deletions

View File

@ -257,19 +257,28 @@ class InvocationRegistry:
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls.get_invocation_typeadapter.cache_clear()
cls.invalidate_invocation_typeadapter()
@classmethod
@lru_cache(maxsize=1)
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation types.
This is used to parse serialized invocations into the correct invocation class.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
@classmethod
def invalidate_invocation_typeadapter(cls) -> None:
"""Invalidates the cached invocation type adapter."""
cls.get_invocation_typeadapter.cache_clear()
@classmethod
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
"""Gets all invocations, respecting the allowlist and denylist."""
@ -306,7 +315,7 @@ class InvocationRegistry:
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls.get_output_typeadapter.cache_clear()
cls.invalidate_output_typeadapter()
@classmethod
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
@ -318,12 +327,21 @@ class InvocationRegistry:
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
This is used to parse serialized invocation outputs into the correct invocation output class.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
@classmethod
def invalidate_output_typeadapter(cls) -> None:
"""Invalidates the cached invocation output type adapter."""
cls.get_output_typeadapter.cache_clear()
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""

View File

@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir):
# We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for
# subsequent graph validations
InvocationRegistry.get_invocation_typeadapter.cache_clear()
InvocationRegistry.invalidate_invocation_typeadapter()
# confirm graph validation fails when using denied node
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}})
@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir):
# Reset the config so that it doesn't affect other tests
get_config.cache_clear()
InvocationRegistry.get_invocation_typeadapter.cache_clear()
InvocationRegistry.invalidate_invocation_typeadapter()