refactor(backend): Defer loading of .blocks and .integrations.webhooks on module init (#9664)

Currently, an import statement like `from backend.blocks.basic import
AgentInputBlock` will initialize `backend.blocks` and thereby load all
other blocks. This has quite high potential to cause circular import
issues, and it's bad for performance in cases where we don't want to
load all blocks (yet).
The same goes for `backend.integrations.webhooks`.

### Changes 🏗️

- Change `__init__.py` of `backend.blocks` and
`backend.integrations.webhooks` to cached loader functions rather than
init-time code
- Change type of `BlockWebhookConfig.provider` to `ProviderName`

<!-- test edit to check that this doesn't break anything -->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Set up and use an agent with a webhook-triggered block

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
This commit is contained in:
Reinier van der Leer 2025-03-24 16:44:45 +01:00 committed by GitHub
parent 6f48515863
commit 4ca1a453c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 152 additions and 118 deletions

View File

@ -2,88 +2,103 @@ import importlib
import os
import re
from pathlib import Path
from typing import Type, TypeVar
from backend.data.block import Block
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
f"Block module {module} error: module name must be lowercase, "
"and contain only alphanumeric characters and underscores."
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from backend.data.block import Block
T = TypeVar("T")
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
_AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
if _AVAILABLE_BLOCKS:
return _AVAILABLE_BLOCKS
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
f"Block module {module} error: module name must be lowercase, "
"and contain only alphanumeric characters and underscores."
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
for block_cls in all_subclasses(Block):
class_name = block_cls.__name__
if class_name.endswith("Base"):
continue
if not class_name.endswith("Block"):
raise ValueError(
f"Block class {class_name} does not end with 'Block'. "
"If you are creating an abstract class, "
"please name the class with 'Base' at the end"
)
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(
f"Block ID {block.name} error: {block.id} is not a valid UUID"
)
if block.id in _AVAILABLE_BLOCKS:
raise ValueError(
f"Block ID {block.name} error: {block.id} is already in use"
)
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Ensure all fields in input_schema and output_schema are annotated SchemaFields
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(
f"{block.name} has a boolean field with no default value"
)
_AVAILABLE_BLOCKS[block.id] = block_cls
return _AVAILABLE_BLOCKS
__all__ = ["load_all_blocks"]
def all_subclasses(cls: type[T]) -> list[type[T]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += all_subclasses(subclass)
return subclasses
for block_cls in all_subclasses(Block):
name = block_cls.__name__
if block_cls.__name__.endswith("Base"):
continue
if not block_cls.__name__.endswith("Block"):
raise ValueError(
f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
)
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
if block.id in AVAILABLE_BLOCKS:
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Make sure all fields in input_schema and output_schema are annotated and has a value
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(f"{block.name} has a boolean field with no default value")
if block.disabled:
continue
AVAILABLE_BLOCKS[block.id] = block_cls
__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]

View File

@ -8,6 +8,7 @@ from backend.data.block import (
BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.compass import CompassWebhookType
@ -42,7 +43,7 @@ class CompassAITriggerBlock(Block):
input_schema=CompassAITriggerBlock.Input,
output_schema=CompassAITriggerBlock.Output,
webhook_config=BlockManualWebhookConfig(
provider="compass",
provider=ProviderName.COMPASS,
webhook_type=CompassWebhookType.TRANSCRIPTION,
),
test_input=[

View File

@ -12,6 +12,7 @@ from backend.data.block import (
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from ._auth import (
TEST_CREDENTIALS,
@ -123,7 +124,7 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
output_schema=GithubPullRequestTriggerBlock.Output,
# --8<-- [start:example-webhook_config]
webhook_config=BlockWebhookConfig(
provider="github",
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",

View File

@ -8,6 +8,7 @@ from backend.data.block import (
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs
@ -82,7 +83,7 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
input_schema=self.Input,
output_schema=self.Output,
webhook_config=BlockWebhookConfig(
provider="slant3d",
provider=ProviderName.SLANT3D,
webhook_type="orders", # Only one type for now
resource_format="", # No resource format needed
event_filter_input="events",

View File

@ -20,6 +20,7 @@ from prisma.models import AgentBlock
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.settings import Config
@ -225,7 +226,7 @@ class BlockManualWebhookConfig(BaseModel):
the user has to manually set up the webhook at the provider.
"""
provider: str
provider: ProviderName
"""The service provider that the webhook connects to"""
webhook_type: str
@ -461,9 +462,9 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def get_blocks() -> dict[str, Type[Block]]:
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
from backend.blocks import load_all_blocks
return AVAILABLE_BLOCKS
return load_all_blocks()
async def initialize_blocks() -> None:

View File

@ -1,22 +1,43 @@
from typing import TYPE_CHECKING
from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
if TYPE_CHECKING:
from ..providers import ProviderName
from ._base import BaseWebhooksManager
# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
handler.PROVIDER_NAME: handler
for handler in [
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
]
}
# --8<-- [end:WEBHOOK_MANAGERS_BY_NAME]
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
__all__ = ["WEBHOOK_MANAGERS_BY_NAME"]
# --8<-- [start:load_webhook_managers]
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
if _WEBHOOK_MANAGERS:
return _WEBHOOK_MANAGERS
from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
_WEBHOOK_MANAGERS.update(
{
handler.PROVIDER_NAME: handler
for handler in [
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
]
}
)
return _WEBHOOK_MANAGERS
# --8<-- [end:load_webhook_managers]
def get_webhook_manager(provider_name: "ProviderName") -> "BaseWebhooksManager":
return load_webhook_managers()[provider_name]()
def supports_webhooks(provider_name: "ProviderName") -> bool:
return provider_name in load_webhook_managers()
__all__ = ["get_webhook_manager", "supports_webhooks"]

View File

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Callable, Optional, cast
from backend.data.block import BlockSchema, BlockWebhookConfig, get_block
from backend.data.graph import set_node_webhook
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
if TYPE_CHECKING:
from backend.data.graph import GraphModel, NodeModel
@ -123,7 +123,7 @@ async def on_node_activate(
return node
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
if not supports_webhooks(provider):
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
@ -133,7 +133,7 @@ async def on_node_activate(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhooks_manager = get_webhook_manager(provider)
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
try:
@ -234,13 +234,13 @@ async def on_node_deactivate(
return node
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
if not supports_webhooks(provider):
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhooks_manager = get_webhook_manager(provider)
if node.webhook_id:
logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")

View File

@ -71,7 +71,7 @@ def get_outputs_with_names(results: List[ExecutionResult]) -> List[Dict[str, str
)
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks]
return [b.to_dict() for b in blocks if not b.disabled]
@v1_router.post(

View File

@ -17,7 +17,7 @@ from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import NeedConfirmation
from backend.util.service import get_service_client
from backend.util.settings import Settings
@ -281,7 +281,7 @@ async def webhook_ingress_generic(
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook_manager = get_webhook_manager(provider)
webhook = await get_webhook(webhook_id)
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
@ -323,7 +323,7 @@ async def webhook_ping(
user_id: Annotated[str, Depends(get_user_id)], # require auth
):
webhook = await get_webhook(webhook_id)
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[webhook.provider]()
webhook_manager = get_webhook_manager(webhook.provider)
credentials = (
creds_manager.get(user_id, webhook.credentials_id)
@ -358,14 +358,6 @@ async def remove_all_webhooks_for_credentials(
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks_by_creds(credentials.id)
if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
if webhooks:
logger.error(
f"Credentials #{credentials.id} for provider {credentials.provider} "
f"are attached to {len(webhooks)} webhooks, "
f"but there is no available WebhooksHandler for {credentials.provider}"
)
return
if any(w.attached_nodes for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
@ -376,7 +368,7 @@ async def remove_all_webhooks_for_credentials(
await set_node_webhook(node.id, None)
# Prune the webhook
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[credentials.provider]()
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
success = await webhook_manager.prune_webhook_if_dangling(
webhook.id, credentials
)

View File

@ -198,7 +198,9 @@ async def get_onboarding_agents(
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
costs = get_block_costs()
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]
return [
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
]
@v1_router.post(

View File

@ -481,10 +481,10 @@ To add support for a new webhook provider, you'll need to create a WebhooksManag
--8<-- "autogpt_platform/backend/backend/integrations/webhooks/_base.py:BaseWebhooksManager5"
```
And add a reference to your `WebhooksManager` class in `WEBHOOK_MANAGERS_BY_NAME`:
And add a reference to your `WebhooksManager` class in `load_webhook_managers`:
```python title="backend/integrations/webhooks/__init__.py"
--8<-- "autogpt_platform/backend/backend/integrations/webhooks/__init__.py:WEBHOOK_MANAGERS_BY_NAME"
--8<-- "autogpt_platform/backend/backend/integrations/webhooks/__init__.py:load_webhook_managers"
```
#### Example: GitHub Webhook Integration