refactor(backend): Reorganize & clean up execution update system (#9663)

- Prep work for #8782
- Prep work for #8779

### Changes 🏗️

- refactor(platform): Differentiate graph/node execution events
- fix(platform): Subscribe to execution updates by `graph_exec_id`
instead of `graph_id`+`graph_version`
- refactor(backend): Move all execution related models and functions
from `.data.graph` to `.data.execution`
- refactor(backend): Reorganize & refactor `.data.execution`

- fix(libs): Remove `load_dotenv` in `.auth.config` to fix test config
issues
- dx: Bump version of `black` in pre-commit config to v24.10.0 to match
poetry.lock

- Other minor refactoring in both frontend and backend

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Run an agent in the builder
    - [x] -> works normally, node I/O is updated in real time
  - Run an agent in the library
    - [x] -> works normally

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
This commit is contained in:
Reinier van der Leer 2025-03-25 13:14:04 +01:00 committed by GitHub
parent 37f212e950
commit 1162ec1474
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 885 additions and 687 deletions

View File

@ -140,7 +140,7 @@ repos:
language: system
- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.10.0
# Black has sensible defaults, doesn't need package context, and ignores
# everything in .gitignore, so it works fine without any config or arguments.
hooks:

View File

@ -1,11 +1,9 @@
from .config import Settings
from .depends import requires_admin_user, requires_user
from .jwt_utils import parse_jwt_token
from .middleware import APIKeyValidator, auth_middleware
from .models import User
__all__ = [
"Settings",
"parse_jwt_token",
"requires_user",
"requires_admin_user",

View File

@ -1,14 +1,11 @@
import os
from dotenv import load_dotenv
load_dotenv()
class Settings:
JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
JWT_ALGORITHM: str = "HS256"
def __init__(self):
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
@property
def is_configured(self) -> bool:

View File

@ -1,6 +1,6 @@
import fastapi
from .config import Settings
from .config import settings
from .middleware import auth_middleware
from .models import DEFAULT_USER_ID, User
@ -17,7 +17,7 @@ def requires_admin_user(
def verify_user(payload: dict | None, admin_only: bool) -> User:
if not payload:
if Settings.ENABLE_AUTH:
if settings.ENABLE_AUTH:
raise fastapi.HTTPException(
status_code=401, detail="Authorization header is missing"
)

View File

@ -1929,4 +1929,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "931772287f71c539575d601e6398423bf68e09ca87ae1a144057c7f5707cf978"
content-hash = "02023e8698c80648fec23a112ec2ec90d617bba83081d194fab90f682908f0f3"

View File

@ -16,7 +16,6 @@ pyjwt = "^2.10.1"
pytest-asyncio = "^0.25.3"
pytest-mock = "^3.14.0"
python = ">=3.10,<4.0"
python-dotenv = "^1.0.1"
supabase = "^2.13.0"
[tool.poetry.group.dev.dependencies]

View File

@ -75,6 +75,8 @@ class AgentExecutorBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
from backend.data.execution import ExecutionEventType
executor_manager = get_executor_manager_client()
event_bus = get_event_bus()
@ -90,11 +92,7 @@ class AgentExecutorBlock(Block):
for event in event_bus.listen(
graph_id=graph_exec.graph_id, graph_exec_id=graph_exec.graph_exec_id
):
logger.info(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if not event.node_id:
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
if event.status in [
ExecutionStatus.COMPLETED,
ExecutionStatus.TERMINATED,
@ -105,6 +103,10 @@ class AgentExecutorBlock(Block):
else:
continue
logger.info(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if not event.block_id:
logger.warning(f"{log_id} received event without block_id {event}")
continue

View File

@ -220,9 +220,8 @@ def event():
@test.command()
@click.argument("server_address")
@click.argument("graph_id")
@click.argument("graph_version")
def websocket(server_address: str, graph_id: str, graph_version: int):
@click.argument("graph_exec_id")
def websocket(server_address: str, graph_exec_id: str):
"""
Tests the websocket connection.
"""
@ -230,16 +229,20 @@ def websocket(server_address: str, graph_id: str, graph_version: int):
import websockets.asyncio.client
from backend.server.ws_api import ExecutionSubscription, Methods, WsMessage
from backend.server.ws_api import (
WSMessage,
WSMethod,
WSSubscribeGraphExecutionRequest,
)
async def send_message(server_address: str):
uri = f"ws://{server_address}"
async with websockets.asyncio.client.connect(uri) as websocket:
try:
msg = WsMessage(
method=Methods.SUBSCRIBE,
data=ExecutionSubscription(
graph_id=graph_id, graph_version=graph_version
msg = WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
data=WSSubscribeGraphExecutionRequest(
graph_exec_id=graph_exec_id,
).model_dump(),
).model_dump_json()
await websocket.send(msg)

View File

@ -1,7 +1,18 @@
import logging
from collections import defaultdict
from datetime import datetime, timezone
from enum import Enum
from multiprocessing import Manager
from typing import Any, AsyncGenerator, Generator, Generic, Optional, Type, TypeVar
from typing import (
Annotated,
Any,
AsyncGenerator,
Generator,
Generic,
Literal,
Optional,
TypeVar,
)
from prisma import Json
from prisma.enums import AgentExecutionStatus
@ -10,62 +21,148 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentNodeExecutionUpdateInput, AgentNodeExecutionWhereInput
from prisma.types import (
AgentGraphExecutionWhereInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
from pydantic import BaseModel
from pydantic.fields import Field
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.queue import AsyncRedisEventBus, RedisEventBus
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock, type
from backend.util import mock
from backend.util import type as type_utils
from backend.util.settings import Config
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
from .model import GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
T = TypeVar("T")
logger = logging.getLogger(__name__)
config = Config()
class NodeExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
node_exec_id: str
node_id: str
block_id: str
data: BlockInput
# -------------------------- Models -------------------------- #
ExecutionStatus = AgentExecutionStatus
T = TypeVar("T")
class GraphExecutionMeta(BaseDbModel):
user_id: str
started_at: datetime
ended_at: datetime
cost: Optional[int] = Field(..., description="Execution cost in credits")
duration: float = Field(..., description="Seconds from start to end of run")
total_run_time: float = Field(..., description="Seconds of node runtime")
status: ExecutionStatus
graph_id: str
graph_version: int
preset_id: Optional[str] = None
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
now = datetime.now(timezone.utc)
start_time = _graph_exec.startedAt or _graph_exec.createdAt
end_time = _graph_exec.updatedAt or now
duration = (end_time - start_time).total_seconds()
total_run_time = duration
try:
stats = GraphExecutionStats.model_validate(_graph_exec.stats)
except ValueError as e:
if _graph_exec.stats is not None:
logger.warning(
"Failed to parse invalid graph execution stats "
f"{_graph_exec.stats}: {e}"
)
stats = None
duration = stats.walltime if stats else duration
total_run_time = stats.nodes_walltime if stats else total_run_time
return GraphExecutionMeta(
id=_graph_exec.id,
user_id=_graph_exec.userId,
started_at=start_time,
ended_at=end_time,
cost=stats.cost if stats else None,
duration=duration,
total_run_time=total_run_time,
status=ExecutionStatus(_graph_exec.executionStatus),
graph_id=_graph_exec.agentGraphId,
graph_version=_graph_exec.agentGraphVersion,
preset_id=_graph_exec.agentPresetId,
)
class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
This will be shared between different processes
"""
class GraphExecution(GraphExecutionMeta):
inputs: BlockInput
outputs: CompletedBlockOutput
node_executions: list["NodeExecutionResult"]
def __init__(self):
self.queue = Manager().Queue()
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
raise ValueError("Node executions must be included in query")
def add(self, execution: T) -> T:
self.queue.put(execution)
return execution
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
def get(self) -> T:
return self.queue.get()
node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
def empty(self) -> bool:
return self.queue.empty()
inputs = {
**{
# inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data.get("value")
for exec in node_executions
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
)
},
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
for exec in node_executions
if (
(block := get_block(exec.block_id))
and block.block_type
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
)
},
}
outputs: CompletedBlockOutput = defaultdict(list)
for exec in node_executions:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
outputs[exec.input_data["name"]].append(
exec.input_data.get("value", None)
)
return GraphExecution(
**{
field_name: getattr(graph_exec, field_name)
for field_name in graph_exec.model_fields
},
inputs=inputs,
outputs=outputs,
node_executions=node_executions,
)
class ExecutionResult(BaseModel):
class NodeExecutionResult(BaseModel):
user_id: str
graph_id: str
graph_version: int
@ -81,41 +178,20 @@ class ExecutionResult(BaseModel):
start_time: datetime | None
end_time: datetime | None
@staticmethod
def from_graph(graph_exec: AgentGraphExecution):
return ExecutionResult(
user_id=graph_exec.userId,
graph_id=graph_exec.agentGraphId,
graph_version=graph_exec.agentGraphVersion,
graph_exec_id=graph_exec.id,
node_exec_id="",
node_id="",
block_id="",
status=graph_exec.executionStatus,
# TODO: Populate input_data & output_data from AgentNodeExecutions
# Input & Output come from AgentInputBlock & AgentOutputBlock.
input_data={},
output_data={},
add_time=graph_exec.createdAt,
queue_time=graph_exec.createdAt,
start_time=graph_exec.startedAt,
end_time=graph_exec.updatedAt,
)
@staticmethod
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
if execution.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type.convert(execution.executionData, dict[str, Any])
input_data = type_utils.convert(execution.executionData, dict[str, Any])
else:
# For an incomplete execution, executionData will not yet be available.
input_data: BlockInput = defaultdict()
for data in execution.Input or []:
input_data[data.name] = type.convert(data.data, Type[Any])
input_data[data.name] = type_utils.convert(data.data, type[Any])
output_data: CompletedBlockOutput = defaultdict(list)
for data in execution.Output or []:
output_data[data.name].append(type.convert(data.data, Type[Any]))
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
if graph_execution:
@ -125,7 +201,7 @@ class ExecutionResult(BaseModel):
"AgentGraphExecution must be included or user_id passed in"
)
return ExecutionResult(
return NodeExecutionResult(
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
@ -146,13 +222,49 @@ class ExecutionResult(BaseModel):
# --------------------- Model functions --------------------- #
async def get_graph_executions(
graph_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> list[GraphExecutionMeta]:
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
executions = await AgentGraphExecution.prisma().find_many(
where=where_filter,
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_graph_execution_meta(
user_id: str, execution_id: str
) -> GraphExecutionMeta | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id}
)
return GraphExecutionMeta.from_db(execution) if execution else None
async def get_graph_execution(user_id: str, execution_id: str) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(execution) if execution else None
async def create_graph_execution(
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> tuple[str, list[ExecutionResult]]:
) -> tuple[str, list[NodeExecutionResult]]:
"""
Create a new AgentGraphExecution record.
Returns:
@ -186,7 +298,7 @@ async def create_graph_execution(
)
return result.id, [
ExecutionResult.from_db(execution, result.userId)
NodeExecutionResult.from_db(execution, result.userId)
for execution in result.AgentNodeExecutions or []
]
@ -236,7 +348,7 @@ async def upsert_execution_input(
)
return existing_execution.id, {
**{
input_data.name: type.convert(input_data.data, Type[Any])
input_data.name: type_utils.convert(input_data.data, type[Any])
for input_data in existing_execution.Input or []
},
input_name: input_data,
@ -276,7 +388,7 @@ async def upsert_execution_output(
)
async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResult:
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecutionMeta:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
@ -285,16 +397,16 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
},
)
if not res:
raise ValueError(f"Execution {graph_exec_id} not found.")
raise ValueError(f"Graph execution #{graph_exec_id} not found")
return ExecutionResult.from_graph(res)
return GraphExecutionMeta.from_db(res)
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: GraphExecutionStats | None = None,
) -> ExecutionResult:
) -> GraphExecutionMeta:
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
@ -312,9 +424,9 @@ async def update_graph_execution_stats(
},
)
if not res:
raise ValueError(f"Execution {graph_exec_id} not found.")
raise ValueError(f"Graph execution #{graph_exec_id} not found")
return ExecutionResult.from_graph(res)
return GraphExecutionMeta.from_db(res)
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
@ -327,7 +439,7 @@ async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionSta
)
async def update_execution_status_batch(
async def update_node_execution_status_batch(
node_exec_ids: list[str],
status: ExecutionStatus,
stats: dict[str, Any] | None = None,
@ -338,12 +450,12 @@ async def update_execution_status_batch(
)
async def update_execution_status(
async def update_node_execution_status(
node_exec_id: str,
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> ExecutionResult:
) -> NodeExecutionResult:
if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.")
@ -355,7 +467,7 @@ async def update_execution_status(
if not res:
raise ValueError(f"Execution {node_exec_id} not found.")
return ExecutionResult.from_db(res)
return NodeExecutionResult.from_db(res)
def _get_update_status_data(
@ -381,7 +493,7 @@ def _get_update_status_data(
return update_data
async def delete_execution(
async def delete_graph_execution(
graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
if soft_delete:
@ -398,12 +510,12 @@ async def delete_execution(
)
async def get_execution_results(
async def get_node_execution_results(
graph_exec_id: str,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
) -> list[ExecutionResult]:
) -> list[NodeExecutionResult]:
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
@ -417,10 +529,110 @@ async def get_execution_results(
include=EXECUTION_RESULT_INCLUDE,
take=limit,
)
res = [ExecutionResult.from_db(execution) for execution in executions]
res = [NodeExecutionResult.from_db(execution) for execution in executions]
return res
async def get_graph_executions_in_timerange(
user_id: str, start_time: str, end_time: str
) -> list[GraphExecution]:
try:
executions = await AgentGraphExecution.prisma().find_many(
where={
"startedAt": {
"gte": datetime.fromisoformat(start_time),
"lte": datetime.fromisoformat(end_time),
},
"userId": user_id,
"isDeleted": False,
},
include=GRAPH_EXECUTION_INCLUDE,
)
return [GraphExecution.from_db(execution) for execution in executions]
except Exception as e:
raise DatabaseError(
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
) from e
async def get_latest_node_execution(
node_id: str, graph_eid: str
) -> NodeExecutionResult | None:
execution = await AgentNodeExecution.prisma().find_first(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order=[
{"queuedTime": "desc"},
{"addedTime": "desc"},
],
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:
return None
return NodeExecutionResult.from_db(execution)
async def get_incomplete_node_executions(
node_id: str, graph_eid: str
) -> list[NodeExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE,
)
return [NodeExecutionResult.from_db(execution) for execution in executions]
# ----------------- Execution Infrastructure ----------------- #
class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
class NodeExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
node_exec_id: str
node_id: str
block_id: str
data: BlockInput
class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
This will be shared between different processes
"""
def __init__(self):
self.queue = Manager().Queue()
def add(self, execution: T) -> T:
self.queue.put(execution)
return execution
def get(self) -> T:
return self.queue.get()
def empty(self) -> bool:
return self.queue.empty()
# ------------------- Execution Utilities -------------------- #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
@ -556,72 +768,82 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
return data
async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult | None:
execution = await AgentNodeExecution.prisma().find_first(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order=[
{"queuedTime": "desc"},
{"addedTime": "desc"},
],
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:
return None
return ExecutionResult.from_db(execution)
async def get_incomplete_executions(
node_id: str, graph_eid: str
) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE,
)
return [ExecutionResult.from_db(execution) for execution in executions]
# --------------------- Event Bus --------------------- #
config = Config()
class ExecutionEventType(str, Enum):
GRAPH_EXEC_UPDATE = "graph_execution_update"
NODE_EXEC_UPDATE = "node_execution_update"
class RedisExecutionEventBus(RedisEventBus[ExecutionResult]):
Model = ExecutionResult
class GraphExecutionEvent(GraphExecutionMeta):
event_type: Literal[ExecutionEventType.GRAPH_EXEC_UPDATE] = (
ExecutionEventType.GRAPH_EXEC_UPDATE
)
class NodeExecutionEvent(NodeExecutionResult):
event_type: Literal[ExecutionEventType.NODE_EXEC_UPDATE] = (
ExecutionEventType.NODE_EXEC_UPDATE
)
ExecutionEvent = Annotated[
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
]
class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
Model = ExecutionEvent # type: ignore
@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name
def publish(self, res: ExecutionResult):
self.publish_event(res, f"{res.graph_id}/{res.graph_exec_id}")
def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
if isinstance(res, GraphExecutionMeta):
self.publish_graph_exec_update(res)
else:
self.publish_node_exec_update(res)
def publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
self.publish_event(event, f"{res.graph_id}/{res.graph_exec_id}")
def publish_graph_exec_update(self, res: GraphExecutionMeta):
event = GraphExecutionEvent.model_validate(res.model_dump())
self.publish_event(event, f"{res.graph_id}/{res.id}")
def listen(
self, graph_id: str = "*", graph_exec_id: str = "*"
) -> Generator[ExecutionResult, None, None]:
for execution_result in self.listen_events(f"{graph_id}/{graph_exec_id}"):
yield execution_result
) -> Generator[ExecutionEvent, None, None]:
for event in self.listen_events(f"{graph_id}/{graph_exec_id}"):
yield event
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionResult]):
Model = ExecutionResult
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
Model = ExecutionEvent # type: ignore
@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name
async def publish(self, res: ExecutionResult):
await self.publish_event(res, f"{res.graph_id}/{res.graph_exec_id}")
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
if isinstance(res, GraphExecutionMeta):
await self.publish_graph_exec_update(res)
else:
await self.publish_node_exec_update(res)
async def publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
await self.publish_event(event, f"{res.graph_id}/{res.graph_exec_id}")
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
event = GraphExecutionEvent.model_validate(res.model_dump())
await self.publish_event(event, f"{res.graph_id}/{res.id}")
async def listen(
self, graph_id: str = "*", graph_exec_id: str = "*"
) -> AsyncGenerator[ExecutionResult, None]:
async for execution_result in self.listen_events(f"{graph_id}/{graph_exec_id}"):
yield execution_result
) -> AsyncGenerator[ExecutionEvent, None]:
async for event in self.listen_events(f"{graph_id}/{graph_exec_id}"):
yield event

View File

@ -1,21 +1,14 @@
import logging
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Type
import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import (
AgentGraph,
AgentGraphExecution,
AgentNode,
AgentNodeLink,
StoreListingVersion,
)
from prisma.types import AgentGraphExecutionWhereInput, AgentGraphWhereInput
from pydantic.fields import Field, computed_field
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
@ -25,15 +18,11 @@ from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .execution import ExecutionResult, ExecutionStatus
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, GRAPH_EXECUTION_INCLUDE
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from .integrations import Webhook
logger = logging.getLogger(__name__)
_INPUT_BLOCK_ID = AgentInputBlock().id
_OUTPUT_BLOCK_ID = AgentOutputBlock().id
class Link(BaseDbModel):
source_id: str
@ -164,111 +153,6 @@ class NodeModel(Node):
Webhook.model_rebuild()
class GraphExecutionMeta(BaseDbModel):
execution_id: str
started_at: datetime
ended_at: datetime
cost: Optional[int] = Field(..., description="Execution cost in credits")
duration: float
total_run_time: float
status: ExecutionStatus
graph_id: str
graph_version: int
preset_id: Optional[str]
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
now = datetime.now(timezone.utc)
start_time = _graph_exec.startedAt or _graph_exec.createdAt
end_time = _graph_exec.updatedAt or now
duration = (end_time - start_time).total_seconds()
total_run_time = duration
try:
stats = type_utils.convert(_graph_exec.stats or {}, dict[str, Any])
except ValueError:
stats = {}
duration = stats.get("walltime", duration)
total_run_time = stats.get("nodes_walltime", total_run_time)
return GraphExecutionMeta(
id=_graph_exec.id,
execution_id=_graph_exec.id,
started_at=start_time,
ended_at=end_time,
cost=stats.get("cost", None),
duration=duration,
total_run_time=total_run_time,
status=ExecutionStatus(_graph_exec.executionStatus),
graph_id=_graph_exec.agentGraphId,
graph_version=_graph_exec.agentGraphVersion,
preset_id=_graph_exec.agentPresetId,
)
class GraphExecution(GraphExecutionMeta):
inputs: dict[str, Any]
outputs: dict[str, list[Any]]
node_executions: list[ExecutionResult]
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
node_executions = sorted(
[
ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
inputs = {
**{
# inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data.get("value")
for exec in node_executions
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
)
},
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
for exec in node_executions
if (
(block := get_block(exec.block_id))
and block.block_type
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
)
},
}
outputs: dict[str, list] = defaultdict(list)
for exec in node_executions:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
outputs[exec.input_data["name"]].append(
exec.input_data.get("value", None)
)
return GraphExecution(
**{
field_name: getattr(graph_exec, field_name)
for field_name in graph_exec.model_fields
},
inputs=inputs,
outputs=outputs,
node_executions=node_executions,
)
class BaseGraph(BaseDbModel):
version: int = 1
is_active: bool = True
@ -644,45 +528,6 @@ async def get_graphs(
return graph_models
async def get_graph_executions(
graph_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> list[GraphExecutionMeta]:
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
executions = await AgentGraphExecution.prisma().find_many(
where=where_filter,
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_execution_meta(
user_id: str, execution_id: str
) -> GraphExecutionMeta | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id}
)
return GraphExecutionMeta.from_db(execution) if execution else None
async def get_execution(
user_id: str,
execution_id: str,
) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(execution) if execution else None
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
where_clause: AgentGraphWhereInput = {
"id": graph_id,

View File

@ -411,6 +411,7 @@ class NodeExecutionStats(BaseModel):
class Config:
arbitrary_types_allowed = True
extra = "allow"
error: Optional[Exception | str] = None
walltime: float = 0
@ -428,12 +429,19 @@ class GraphExecutionStats(BaseModel):
class Config:
arbitrary_types_allowed = True
extra = "allow"
error: Optional[Exception | str] = None
walltime: float = 0
walltime: float = Field(
default=0, description="Time between start and end of run (seconds)"
)
cputime: float = 0
nodes_walltime: float = 0
nodes_walltime: float = Field(
default=0, description="Total node execution time (seconds)"
)
nodes_cputime: float = 0
node_count: int = 0
node_error_count: int = 0
cost: float = 0
node_count: int = Field(default=0, description="Total number of node executions")
node_error_count: int = Field(
default=0, description="Total number of errors generated"
)
cost: int = Field(default=0, description="Total execution cost (cents)")

View File

@ -1,8 +1,6 @@
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
from pydantic import BaseModel
@ -14,13 +12,6 @@ from backend.data import redis
logger = logging.getLogger(__name__)
class DateTimeEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, datetime):
return o.isoformat()
return super().default(o)
M = TypeVar("M", bound=BaseModel)
@ -32,8 +23,12 @@ class BaseRedisEventBus(Generic[M], ABC):
def event_bus_name(self) -> str:
pass
@property
def Message(self) -> type["_EventPayloadWrapper[M]"]:
return _EventPayloadWrapper[self.Model]
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
message = self.Message(payload=item).model_dump_json()
channel_name = f"{self.event_bus_name}/{channel_key}"
logger.debug(f"[{channel_name}] Publishing an event to Redis {message}")
return message, channel_name
@ -43,9 +38,8 @@ class BaseRedisEventBus(Generic[M], ABC):
if msg["type"] != message_type:
return None
try:
data = json.loads(msg["data"])
logger.debug(f"Consuming an event from Redis {data}")
return self.Model(**data)
logger.debug(f"[{channel_key}] Consuming an event from Redis {msg['data']}")
return self.Message.model_validate_json(msg["data"]).payload
except Exception as e:
logger.error(f"Failed to parse event result from Redis {msg} {e}")
@ -57,9 +51,16 @@ class BaseRedisEventBus(Generic[M], ABC):
return pubsub, full_channel_name
class RedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]
class _EventPayloadWrapper(BaseModel, Generic[M]):
    """
    Envelope model with a single `payload` field that holds the event.

    Wrapping the event allows `RedisEventBus.Model` to be a discriminated
    union of multiple event types.
    """

    payload: M
class RedisEventBus(BaseRedisEventBus[M], ABC):
@property
def connection(self) -> redis.Redis:
return redis.get_redis()
@ -85,8 +86,6 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]
@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()

View File

@ -1,16 +1,17 @@
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
ExecutionResult,
GraphExecutionMeta,
NodeExecutionResult,
RedisExecutionEventBus,
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_latest_execution,
update_execution_status,
update_execution_status_batch,
get_incomplete_node_executions,
get_latest_node_execution,
get_node_execution_results,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
update_node_execution_status,
update_node_execution_status_batch,
upsert_execution_input,
upsert_execution_output,
)
@ -55,23 +56,29 @@ class DatabaseManager(AppService):
super().__init__()
self.use_db = True
self.use_redis = True
self.event_queue = RedisExecutionEventBus()
self.execution_event_bus = RedisExecutionEventBus()
@classmethod
def get_port(cls) -> int:
return config.database_api_port
@expose
def send_execution_update(self, execution_result: ExecutionResult):
self.event_queue.publish(execution_result)
def send_execution_update(
self, execution_result: GraphExecutionMeta | NodeExecutionResult
):
self.execution_event_bus.publish(execution_result)
# Executions
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_execution_results = exposed_run_and_wait(get_execution_results)
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_execution_status_batch = exposed_run_and_wait(update_execution_status_batch)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(
get_incomplete_node_executions
)
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
update_node_execution_status_batch = exposed_run_and_wait(
update_node_execution_status_batch
)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)

View File

@ -40,10 +40,10 @@ from backend.data.block import (
)
from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
merge_execution_input,
parse_execution_output,
)
@ -149,8 +149,9 @@ def execute_node(
node_exec_id = data.node_exec_id
node_id = data.node_id
def update_execution(status: ExecutionStatus) -> ExecutionResult:
exec_update = db_client.update_execution_status(node_exec_id, status)
def update_execution(status: ExecutionStatus) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = db_client.update_node_execution_status(node_exec_id, status)
db_client.send_execution_update(exec_update)
return exec_update
@ -281,7 +282,7 @@ def _enqueue_next_nodes(
def add_enqueued_execution(
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
) -> NodeExecutionEntry:
exec_update = db_client.update_execution_status(
exec_update = db_client.update_node_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
@ -326,7 +327,7 @@ def _enqueue_next_nodes(
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := db_client.get_latest_execution(
latest_execution := db_client.get_latest_node_execution(
next_node_id, graph_exec_id
)
):
@ -359,7 +360,7 @@ def _enqueue_next_nodes(
# If link is static, there could be some incomplete executions waiting for it.
# Load and complete the missing input data, and try to re-enqueue them.
for iexec in db_client.get_incomplete_executions(
for iexec in db_client.get_incomplete_node_executions(
next_node_id, graph_exec_id
):
idata = iexec.input_data
@ -732,7 +733,6 @@ class Executor:
running_executions: dict[str, AsyncResult] = {}
def make_exec_callback(exec_data: NodeExecutionEntry):
def callback(result: object):
running_executions.pop(exec_data.node_id)
@ -778,7 +778,7 @@ class Executor:
exec_id = exec_data.node_exec_id
cls.db_client.upsert_execution_output(exec_id, "error", str(error))
exec_update = cls.db_client.update_execution_status(
exec_update = cls.db_client.update_node_execution_status(
exec_id, ExecutionStatus.FAILED
)
cls.db_client.send_execution_update(exec_update)
@ -842,7 +842,7 @@ class Executor:
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_execution_results(
outputs = cls.db_client.get_node_execution_results(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
@ -1061,7 +1061,7 @@ class ExecutionManager(AppService):
graph_exec_id,
ExecutionStatus.TERMINATED,
)
node_execs = self.db_client.get_execution_results(
node_execs = self.db_client.get_node_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
@ -1069,7 +1069,7 @@ class ExecutionManager(AppService):
ExecutionStatus.INCOMPLETE,
],
)
self.db_client.update_execution_status_batch(
self.db_client.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
)

View File

@ -2,8 +2,17 @@ from typing import Dict, Set
from fastapi import WebSocket
from backend.data import execution
from backend.server.model import Methods, WsMessage
from backend.data.execution import (
ExecutionEventType,
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.server.model import WSMessage, WSMethod
# Looks up the websocket method under which an execution event of a given
# type is sent to subscribers.
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
    ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
    ExecutionEventType.NODE_EXEC_UPDATE: WSMethod.NODE_EXECUTION_EVENT,
}
class ConnectionManager:
@ -11,39 +20,58 @@ class ConnectionManager:
self.active_connections: Set[WebSocket] = set()
self.subscriptions: Dict[str, Set[WebSocket]] = {}
async def connect(self, websocket: WebSocket):
async def connect_socket(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.add(websocket)
def disconnect(self, websocket: WebSocket):
def disconnect_socket(self, websocket: WebSocket):
self.active_connections.remove(websocket)
for subscribers in self.subscriptions.values():
subscribers.discard(websocket)
async def subscribe(
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
):
key = f"{user_id}_{graph_id}_{graph_version}"
async def subscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str:
key = _graph_exec_channel_key(user_id, graph_exec_id)
if key not in self.subscriptions:
self.subscriptions[key] = set()
self.subscriptions[key].add(websocket)
return key
async def unsubscribe(
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
):
key = f"{user_id}_{graph_id}_{graph_version}"
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str | None:
key = _graph_exec_channel_key(user_id, graph_exec_id)
if key in self.subscriptions:
self.subscriptions[key].discard(websocket)
if not self.subscriptions[key]:
del self.subscriptions[key]
return key
return None
async def send_execution_result(self, result: execution.ExecutionResult):
key = f"{result.user_id}_{result.graph_id}_{result.graph_version}"
async def send_execution_update(
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
) -> int:
graph_exec_id = (
exec_event.id
if isinstance(exec_event, GraphExecutionEvent)
else exec_event.graph_exec_id
)
key = _graph_exec_channel_key(exec_event.user_id, graph_exec_id)
n_sent = 0
if key in self.subscriptions:
message = WsMessage(
method=Methods.EXECUTION_EVENT,
message = WSMessage(
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
channel=key,
data=result.model_dump(),
data=exec_event.model_dump(),
).model_dump_json()
for connection in self.subscriptions[key]:
await connection.send_text(message)
n_sent += 1
return n_sent
def _graph_exec_channel_key(user_id: str, graph_exec_id: str) -> str:
return f"{user_id}|graph_exec#{graph_exec_id}"

View File

@ -12,7 +12,7 @@ from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import ExecutionResult
from backend.data.execution import NodeExecutionResult
from backend.executor import ExecutionManager
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
@ -53,7 +53,7 @@ class GraphExecutionResult(TypedDict):
output: Optional[List[Dict[str, str]]]
def get_outputs_with_names(results: List[ExecutionResult]) -> List[Dict[str, str]]:
def get_outputs_with_names(results: list[NodeExecutionResult]) -> list[dict[str, str]]:
outputs = []
for result in results:
if "output" in result.output_data:
@ -130,7 +130,7 @@ async def get_graph_execution_results(
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
results = await execution_db.get_execution_results(graph_exec_id)
results = await execution_db.get_node_execution_results(graph_exec_id)
last_result = results[-1] if results else None
execution_status = (
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE

View File

@ -1,31 +1,31 @@
import enum
from typing import Any, List, Optional, Union
from typing import Any, Optional
import pydantic
import backend.data.graph
from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
from backend.data.graph import Graph
class Methods(enum.Enum):
SUBSCRIBE = "subscribe"
class WSMethod(enum.Enum):
SUBSCRIBE_GRAPH_EXEC = "subscribe_graph_execution"
UNSUBSCRIBE = "unsubscribe"
EXECUTION_EVENT = "execution_event"
GRAPH_EXECUTION_EVENT = "graph_execution_event"
NODE_EXECUTION_EVENT = "node_execution_event"
ERROR = "error"
HEARTBEAT = "heartbeat"
class WsMessage(pydantic.BaseModel):
method: Methods
data: Optional[Union[dict[str, Any], list[Any], str]] = None
class WSMessage(pydantic.BaseModel):
method: WSMethod
data: Optional[dict[str, Any] | list[Any] | str] = None
success: bool | None = None
channel: str | None = None
error: str | None = None
class ExecutionSubscription(pydantic.BaseModel):
graph_id: str
graph_version: int
class WSSubscribeGraphExecutionRequest(pydantic.BaseModel):
    """Websocket request data naming the graph execution to (un)subscribe to."""

    graph_exec_id: str
class ExecuteGraphResponse(pydantic.BaseModel):
@ -33,12 +33,12 @@ class ExecuteGraphResponse(pydantic.BaseModel):
class CreateGraph(pydantic.BaseModel):
graph: backend.data.graph.Graph
graph: Graph
class CreateAPIKeyRequest(pydantic.BaseModel):
name: str
permissions: List[APIKeyPermission]
permissions: list[APIKeyPermission]
description: Optional[str] = None
@ -52,7 +52,7 @@ class SetGraphActiveVersion(pydantic.BaseModel):
class UpdatePermissionsRequest(pydantic.BaseModel):
permissions: List[APIKeyPermission]
permissions: list[APIKeyPermission]
class Pagination(pydantic.BaseModel):

View File

@ -177,7 +177,9 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
execution = await backend.data.graph.get_execution_meta(
from backend.data.execution import get_graph_execution_meta
execution = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not execution:

View File

@ -599,8 +599,8 @@ def execute_graph(
)
async def stop_graph_run(
graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphExecution:
if not await graph_db.get_execution_meta(
) -> execution_db.GraphExecution:
if not await execution_db.get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
@ -610,7 +610,9 @@ async def stop_graph_run(
)
# Retrieve & return canceled graph execution in its final state
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
result = await execution_db.get_graph_execution(
execution_id=graph_exec_id, user_id=user_id
)
if not result:
raise HTTPException(
500,
@ -626,8 +628,8 @@ async def stop_graph_run(
)
async def get_graphs_executions(
user_id: Annotated[str, Depends(get_user_id)],
) -> list[graph_db.GraphExecutionMeta]:
return await graph_db.get_graph_executions(user_id=user_id)
) -> list[execution_db.GraphExecutionMeta]:
return await execution_db.get_graph_executions(user_id=user_id)
@v1_router.get(
@ -638,8 +640,8 @@ async def get_graphs_executions(
async def get_graph_executions(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[graph_db.GraphExecutionMeta]:
return await graph_db.get_graph_executions(graph_id=graph_id, user_id=user_id)
) -> list[execution_db.GraphExecutionMeta]:
return await execution_db.get_graph_executions(graph_id=graph_id, user_id=user_id)
@v1_router.get(
@ -651,8 +653,10 @@ async def get_graph_execution(
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.GraphExecution:
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
) -> execution_db.GraphExecution:
result = await execution_db.get_graph_execution(
execution_id=graph_exec_id, user_id=user_id
)
if not result or result.graph_id != graph_id:
raise HTTPException(
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
@ -671,7 +675,9 @@ async def delete_graph_execution(
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> None:
await execution_db.delete_execution(graph_exec_id=graph_exec_id, user_id=user_id)
await execution_db.delete_graph_execution(
graph_exec_id=graph_exec_id, user_id=user_id
)
########################################################

View File

@ -12,7 +12,7 @@ from backend.data import redis
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
from backend.server.model import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
from backend.util.service import AppProcess, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
@ -52,7 +52,7 @@ async def event_broadcaster(manager: ConnectionManager):
redis.connect()
event_queue = AsyncRedisExecutionEventBus()
async for event in event_queue.listen():
await manager.send_execution_result(event)
await manager.send_execution_update(event)
except Exception as e:
logger.exception(f"Event broadcaster error: {e}")
raise
@ -85,18 +85,18 @@ async def handle_subscribe(
connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WsMessage,
message: WSMessage,
):
if not message.data:
await websocket.send_text(
WsMessage(
method=Methods.ERROR,
WSMessage(
method=WSMethod.ERROR,
success=False,
error="Subscription data missing",
).model_dump_json()
)
else:
sub_req = ExecutionSubscription.model_validate(message.data)
sub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
# Verify that user has read access to graph
# if not get_db_client().get_graph(
@ -113,21 +113,20 @@ async def handle_subscribe(
# )
# return
await connection_manager.subscribe(
channel_key = await connection_manager.subscribe_graph_exec(
user_id=user_id,
graph_id=sub_req.graph_id,
graph_version=sub_req.graph_version,
graph_exec_id=sub_req.graph_exec_id,
websocket=websocket,
)
logger.debug(
f"New execution subscription for user #{user_id} "
f"graph #{sub_req.graph_id}v{sub_req.graph_version}"
f"New subscription for user #{user_id}, "
f"graph execution #{sub_req.graph_exec_id}"
)
await websocket.send_text(
WsMessage(
method=Methods.SUBSCRIBE,
WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
success=True,
channel=f"{user_id}_{sub_req.graph_id}_{sub_req.graph_version}",
channel=channel_key,
).model_dump_json()
)
@ -136,33 +135,32 @@ async def handle_unsubscribe(
connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WsMessage,
message: WSMessage,
):
if not message.data:
await websocket.send_text(
WsMessage(
method=Methods.ERROR,
WSMessage(
method=WSMethod.ERROR,
success=False,
error="Subscription data missing",
).model_dump_json()
)
else:
unsub_req = ExecutionSubscription.model_validate(message.data)
await connection_manager.unsubscribe(
unsub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
channel_key = await connection_manager.unsubscribe(
user_id=user_id,
graph_id=unsub_req.graph_id,
graph_version=unsub_req.graph_version,
graph_exec_id=unsub_req.graph_exec_id,
websocket=websocket,
)
logger.debug(
f"Removed execution subscription for user #{user_id} "
f"graph #{unsub_req.graph_id}v{unsub_req.graph_version}"
f"Removed subscription for user #{user_id}, "
f"graph execution #{unsub_req.graph_exec_id}"
)
await websocket.send_text(
WsMessage(
method=Methods.UNSUBSCRIBE,
WSMessage(
method=WSMethod.UNSUBSCRIBE,
success=True,
channel=f"{unsub_req.graph_id}_{unsub_req.graph_version}",
channel=channel_key,
).model_dump_json()
)
@ -179,20 +177,24 @@ async def websocket_router(
user_id = await authenticate_websocket(websocket)
if not user_id:
return
await manager.connect(websocket)
await manager.connect_socket(websocket)
try:
while True:
data = await websocket.receive_text()
message = WsMessage.model_validate_json(data)
message = WSMessage.model_validate_json(data)
if message.method == Methods.HEARTBEAT:
if message.method == WSMethod.HEARTBEAT:
await websocket.send_json(
{"method": Methods.HEARTBEAT.value, "data": "pong", "success": True}
{
"method": WSMethod.HEARTBEAT.value,
"data": "pong",
"success": True,
}
)
continue
try:
if message.method == Methods.SUBSCRIBE:
if message.method == WSMethod.SUBSCRIBE_GRAPH_EXEC:
await handle_subscribe(
connection_manager=manager,
websocket=websocket,
@ -201,7 +203,7 @@ async def websocket_router(
)
continue
elif message.method == Methods.UNSUBSCRIBE:
elif message.method == WSMethod.UNSUBSCRIBE:
await handle_unsubscribe(
connection_manager=manager,
websocket=websocket,
@ -216,7 +218,7 @@ async def websocket_router(
)
continue
if message.method == Methods.ERROR:
if message.method == WSMethod.ERROR:
logger.error(f"WebSocket Error message received: {message.data}")
else:
@ -225,15 +227,15 @@ async def websocket_router(
f"{message.data}"
)
await websocket.send_text(
WsMessage(
method=Methods.ERROR,
WSMessage(
method=WSMethod.ERROR,
success=False,
error="Message type is not processed by the server",
).model_dump_json()
)
except WebSocketDisconnect:
manager.disconnect(websocket)
manager.disconnect_socket(websocket)
logger.debug("WebSocket client disconnected")

View File

@ -5,7 +5,7 @@ from typing import Sequence, cast
from backend.data import db
from backend.data.block import Block, BlockSchema, initialize_blocks
from backend.data.execution import ExecutionResult, ExecutionStatus
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.model import _BaseCredentials
from backend.data.user import create_default_user
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
@ -63,7 +63,7 @@ async def wait_execution(
graph_id: str,
graph_exec_id: str,
timeout: int = 30,
) -> Sequence[ExecutionResult]:
) -> Sequence[NodeExecutionResult]:
async def is_execution_completed():
status = await AgentServer().test_get_graph_run_status(graph_exec_id, user_id)
log.info(f"Execution status: {status}")

View File

@ -299,7 +299,6 @@ pydantic-settings = "^2.7.1"
pyjwt = "^2.10.1"
pytest-asyncio = "^0.25.3"
pytest-mock = "^3.14.0"
python-dotenv = "^1.0.1"
supabase = "^2.13.0"
[package.source]

View File

@ -2,9 +2,12 @@ import logging
import os
import pytest
from dotenv import load_dotenv
from backend.util.logging import configure_logging
load_dotenv()
# NOTE: You can run tests with --log-cli-level=INFO to see the logs
# Set up logging
configure_logging()

View File

@ -4,9 +4,13 @@ from unittest.mock import AsyncMock
import pytest
from fastapi import WebSocket
from backend.data.execution import ExecutionResult, ExecutionStatus
from backend.data.execution import (
ExecutionStatus,
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.server.conn_manager import ConnectionManager
from backend.server.model import Methods, WsMessage
from backend.server.model import WSMessage, WSMethod
@pytest.fixture
@ -25,7 +29,7 @@ def mock_websocket() -> AsyncMock:
async def test_connect(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
await connection_manager.connect(mock_websocket)
await connection_manager.connect_socket(mock_websocket)
assert mock_websocket in connection_manager.active_connections
mock_websocket.accept.assert_called_once()
@ -34,37 +38,39 @@ def test_disconnect(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.active_connections.add(mock_websocket)
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
connection_manager.subscriptions["test_channel_42"] = {mock_websocket}
connection_manager.disconnect(mock_websocket)
connection_manager.disconnect_socket(mock_websocket)
assert mock_websocket not in connection_manager.active_connections
assert mock_websocket not in connection_manager.subscriptions["test_graph_1"]
assert mock_websocket not in connection_manager.subscriptions["test_channel_42"]
@pytest.mark.asyncio
async def test_subscribe(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
await connection_manager.subscribe(
await connection_manager.subscribe_graph_exec(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="graph-exec-1",
websocket=mock_websocket,
)
assert mock_websocket in connection_manager.subscriptions["user-1_test_graph_1"]
assert (
mock_websocket
in connection_manager.subscriptions["user-1|graph_exec#graph-exec-1"]
)
@pytest.mark.asyncio
async def test_unsubscribe(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
channel_key = "user-1|graph_exec#graph-exec-1"
connection_manager.subscriptions[channel_key] = {mock_websocket}
await connection_manager.unsubscribe(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="graph-exec-1",
websocket=mock_websocket,
)
@ -72,15 +78,46 @@ async def test_unsubscribe(
@pytest.mark.asyncio
async def test_send_execution_result(
async def test_send_graph_execution_result(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
result: ExecutionResult = ExecutionResult(
channel_key = "user-1|graph_exec#graph-exec-1"
connection_manager.subscriptions[channel_key] = {mock_websocket}
result = GraphExecutionEvent(
id="graph-exec-1",
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",
status=ExecutionStatus.COMPLETED,
cost=0,
duration=1.2,
total_run_time=0.5,
started_at=datetime.now(tz=timezone.utc),
ended_at=datetime.now(tz=timezone.utc),
)
await connection_manager.send_execution_update(result)
mock_websocket.send_text.assert_called_once_with(
WSMessage(
method=WSMethod.GRAPH_EXECUTION_EVENT,
channel="user-1|graph_exec#graph-exec-1",
data=result.model_dump(),
).model_dump_json()
)
@pytest.mark.asyncio
async def test_send_node_execution_result(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
channel_key = "user-1|graph_exec#graph-exec-1"
connection_manager.subscriptions[channel_key] = {mock_websocket}
result = NodeExecutionEvent(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="graph-exec-1",
node_exec_id="test_node_exec_id",
node_id="test_node_id",
block_id="test_block_id",
@ -93,12 +130,12 @@ async def test_send_execution_result(
end_time=datetime.now(tz=timezone.utc),
)
await connection_manager.send_execution_result(result)
await connection_manager.send_execution_update(result)
mock_websocket.send_text.assert_called_once_with(
WsMessage(
method=Methods.EXECUTION_EVENT,
channel="user-1_test_graph_1",
WSMessage(
method=WSMethod.NODE_EXECUTION_EVENT,
channel="user-1|graph_exec#graph-exec-1",
data=result.model_dump(),
).model_dump_json()
)
@ -108,12 +145,13 @@ async def test_send_execution_result(
async def test_send_execution_result_user_mismatch(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
result: ExecutionResult = ExecutionResult(
channel_key = "user-1|graph_exec#graph-exec-1"
connection_manager.subscriptions[channel_key] = {mock_websocket}
result = NodeExecutionEvent(
user_id="user-2",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",
graph_exec_id="graph-exec-1",
node_exec_id="test_node_exec_id",
node_id="test_node_id",
block_id="test_block_id",
@ -126,7 +164,7 @@ async def test_send_execution_result_user_mismatch(
end_time=datetime.now(tz=timezone.utc),
)
await connection_manager.send_execution_result(result)
await connection_manager.send_execution_update(result)
mock_websocket.send_text.assert_not_called()
@ -135,7 +173,7 @@ async def test_send_execution_result_user_mismatch(
async def test_send_execution_result_no_subscribers(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
result: ExecutionResult = ExecutionResult(
result = NodeExecutionEvent(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
@ -152,6 +190,6 @@ async def test_send_execution_result_no_subscribers(
end_time=datetime.now(),
)
await connection_manager.send_execution_result(result)
await connection_manager.send_execution_update(result)
mock_websocket.send_text.assert_not_called()

View File

@ -7,8 +7,8 @@ from fastapi import WebSocket, WebSocketDisconnect
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.ws_api import (
Methods,
WsMessage,
WSMessage,
WSMethod,
handle_subscribe,
handle_unsubscribe,
websocket_router,
@ -30,28 +30,33 @@ async def test_websocket_router_subscribe(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
mock_websocket.receive_text.side_effect = [
WsMessage(
method=Methods.SUBSCRIBE,
data={"graph_id": "test_graph", "graph_version": 1},
WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
data={"graph_exec_id": "test-graph-exec-1"},
).model_dump_json(),
WebSocketDisconnect(),
]
mock_manager.subscribe_graph_exec.return_value = (
f"{DEFAULT_USER_ID}|graph_exec#test-graph-exec-1"
)
await websocket_router(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
)
mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.subscribe.assert_called_once_with(
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
mock_manager.subscribe_graph_exec.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
graph_version=1,
graph_exec_id="test-graph-exec-1",
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
assert (
'"method":"subscribe_graph_execution"'
in mock_websocket.send_text.call_args[0][0]
)
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
mock_manager.disconnect.assert_called_once_with(mock_websocket)
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
@pytest.mark.asyncio
@ -59,28 +64,30 @@ async def test_websocket_router_unsubscribe(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
mock_websocket.receive_text.side_effect = [
WsMessage(
method=Methods.UNSUBSCRIBE,
data={"graph_id": "test_graph", "graph_version": 1},
WSMessage(
method=WSMethod.UNSUBSCRIBE,
data={"graph_exec_id": "test-graph-exec-1"},
).model_dump_json(),
WebSocketDisconnect(),
]
mock_manager.unsubscribe.return_value = (
f"{DEFAULT_USER_ID}|graph_exec#test-graph-exec-1"
)
await websocket_router(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
)
mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
mock_manager.unsubscribe.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
graph_version=1,
graph_exec_id="test-graph-exec-1",
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
mock_manager.disconnect.assert_called_once_with(mock_websocket)
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
@pytest.mark.asyncio
@ -88,7 +95,7 @@ async def test_websocket_router_invalid_method(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
mock_websocket.receive_text.side_effect = [
WsMessage(method=Methods.EXECUTION_EVENT).model_dump_json(),
WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT).model_dump_json(),
WebSocketDisconnect(),
]
@ -96,19 +103,23 @@ async def test_websocket_router_invalid_method(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
)
mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
mock_websocket.send_text.assert_called_once()
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
mock_manager.disconnect.assert_called_once_with(mock_websocket)
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
@pytest.mark.asyncio
async def test_handle_subscribe_success(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
message = WsMessage(
method=Methods.SUBSCRIBE, data={"graph_id": "test_graph", "graph_version": 1}
message = WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
data={"graph_exec_id": "test-graph-exec-id"},
)
mock_manager.subscribe_graph_exec.return_value = (
"user-1|graph_exec#test-graph-exec-id"
)
await handle_subscribe(
@ -118,14 +129,16 @@ async def test_handle_subscribe_success(
message=message,
)
mock_manager.subscribe.assert_called_once_with(
mock_manager.subscribe_graph_exec.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test-graph-exec-id",
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
assert (
'"method":"subscribe_graph_execution"'
in mock_websocket.send_text.call_args[0][0]
)
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@ -133,7 +146,7 @@ async def test_handle_subscribe_success(
async def test_handle_subscribe_missing_data(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
message = WsMessage(method=Methods.SUBSCRIBE)
message = WSMessage(method=WSMethod.SUBSCRIBE_GRAPH_EXEC)
await handle_subscribe(
connection_manager=cast(ConnectionManager, mock_manager),
@ -142,7 +155,7 @@ async def test_handle_subscribe_missing_data(
message=message,
)
mock_manager.subscribe.assert_not_called()
mock_manager.subscribe_graph_exec.assert_not_called()
mock_websocket.send_text.assert_called_once()
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
@ -152,9 +165,10 @@ async def test_handle_subscribe_missing_data(
async def test_handle_unsubscribe_success(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
message = WsMessage(
method=Methods.UNSUBSCRIBE, data={"graph_id": "test_graph", "graph_version": 1}
message = WSMessage(
method=WSMethod.UNSUBSCRIBE, data={"graph_exec_id": "test-graph-exec-id"}
)
mock_manager.unsubscribe.return_value = "user-1|graph_exec#test-graph-exec-id"
await handle_unsubscribe(
connection_manager=cast(ConnectionManager, mock_manager),
@ -165,8 +179,7 @@ async def test_handle_unsubscribe_success(
mock_manager.unsubscribe.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test-graph-exec-id",
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
@ -178,7 +191,7 @@ async def test_handle_unsubscribe_success(
async def test_handle_unsubscribe_missing_data(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
message = WsMessage(method=Methods.UNSUBSCRIBE)
message = WSMessage(method=WSMethod.UNSUBSCRIBE)
await handle_unsubscribe(
connection_manager=cast(ConnectionManager, mock_manager),

View File

@ -0,0 +1,21 @@
import AgentFlowListSkeleton from "@/components/monitor/skeletons/AgentFlowListSkeleton";
import React from "react";
import FlowRunsListSkeleton from "@/components/monitor/skeletons/FlowRunsListSkeleton";
import FlowRunsStatusSkeleton from "@/components/monitor/skeletons/FlowRunsStatusSkeleton";
/**
 * Placeholder layout for the monitor page: renders the agent list, run list
 * and run stats skeletons side by side in a three-column grid (single column
 * below the `md` breakpoint).
 */
export default function MonitorLoadingSkeleton() {
  const skeletonGrid = (
    <div className="grid grid-cols-1 gap-4 md:grid-cols-3">
      {/* Agents Section */}
      <AgentFlowListSkeleton />

      {/* Runs Section */}
      <FlowRunsListSkeleton />

      {/* Stats Section */}
      <FlowRunsStatusSkeleton />
    </div>
  );

  return <div className="space-y-4 p-4">{skeletonGrid}</div>;
}

View File

@ -90,7 +90,7 @@ export default function AgentRunsPage(): React.ReactElement {
);
api.getGraphExecutions(agent.agent_id).then((agentRuns) => {
const sortedRuns = agentRuns.toSorted(
(a, b) => b.started_at - a.started_at,
(a, b) => Number(b.started_at) - Number(a.started_at),
);
setAgentRuns(sortedRuns);
@ -102,7 +102,7 @@ export default function AgentRunsPage(): React.ReactElement {
if (!selectedView.id && isFirstLoad && sortedRuns.length > 0) {
// only for first load or first execution
setIsFirstLoad(false);
selectView({ type: "run", id: sortedRuns[0].execution_id });
selectView({ type: "run", id: sortedRuns[0].id });
}
});
});
@ -121,10 +121,8 @@ export default function AgentRunsPage(): React.ReactElement {
useEffect(() => {
if (selectedView.type != "run" || !selectedView.id || !agent) return;
const newSelectedRun = agentRuns.find(
(run) => run.execution_id == selectedView.id,
);
if (selectedView.id !== selectedRun?.execution_id) {
const newSelectedRun = agentRuns.find((run) => run.id == selectedView.id);
if (selectedView.id !== selectedRun?.id) {
// Pull partial data from "cache" while waiting for the rest to load
setSelectedRun(newSelectedRun ?? null);
@ -136,14 +134,7 @@ export default function AgentRunsPage(): React.ReactElement {
setSelectedRun(run);
});
}
}, [
api,
selectedView,
agent,
agentRuns,
selectedRun?.execution_id,
getGraphVersion,
]);
}, [api, selectedView, agent, agentRuns, selectedRun?.id, getGraphVersion]);
const fetchSchedules = useCallback(async () => {
if (!agent) return;
@ -169,17 +160,15 @@ export default function AgentRunsPage(): React.ReactElement {
const deleteRun = useCallback(
async (run: GraphExecutionMeta) => {
if (run.status == "RUNNING" || run.status == "QUEUED") {
await api.stopGraphExecution(run.graph_id, run.execution_id);
await api.stopGraphExecution(run.graph_id, run.id);
}
await api.deleteGraphExecution(run.execution_id);
await api.deleteGraphExecution(run.id);
setConfirmingDeleteAgentRun(null);
if (selectedView.type == "run" && selectedView.id == run.execution_id) {
if (selectedView.type == "run" && selectedView.id == run.id) {
openRunDraftView();
}
setAgentRuns(
agentRuns.filter((r) => r.execution_id !== run.execution_id),
);
setAgentRuns(agentRuns.filter((r) => r.id !== run.id));
},
[agentRuns, api, selectedView, openRunDraftView],
);

View File

@ -102,9 +102,7 @@ const Monitor = () => {
: executions),
].sort((a, b) => Number(b.started_at) - Number(a.started_at))}
selectedRun={selectedRun}
onSelectRun={(r) =>
setSelectedRun(r.execution_id == selectedRun?.execution_id ? null : r)
}
onSelectRun={(r) => setSelectedRun(r.id == selectedRun?.id ? null : r)}
/>
{(selectedRun && (
<FlowRunInfo

View File

@ -90,8 +90,8 @@ export default function AgentRunDetailsView({
);
const stopRun = useCallback(
() => api.stopGraphExecution(graph.id, run.execution_id),
[api, graph.id, run.execution_id],
() => api.stopGraphExecution(graph.id, run.id),
[api, graph.id, run.id],
);
const agentRunOutputs:

View File

@ -111,8 +111,8 @@ export default function AgentRunsSelectorList({
status={agentRunStatusMap[run.status]}
title={agent.name}
timestamp={run.started_at}
selected={selectedView.id === run.execution_id}
onClick={() => onSelectRun(run.execution_id)}
selected={selectedView.id === run.id}
onClick={() => onSelectRun(run.id)}
onDelete={() => onDeleteRun(run)}
/>
))

View File

@ -125,7 +125,9 @@ export const AgentFlowList = ({
if (!a.lastRun && !b.lastRun) return 0;
if (!a.lastRun) return 1;
if (!b.lastRun) return -1;
return b.lastRun.started_at - a.lastRun.started_at;
return (
Number(b.lastRun.started_at) - Number(a.lastRun.started_at)
);
})
.map(({ flow, runCount, lastRun }) => (
<TableRow

View File

@ -27,7 +27,7 @@ export const FlowRunInfo: React.FC<
const fetchBlockResults = useCallback(async () => {
const executionResults = (
await api.getGraphExecutionInfo(flow.agent_id, execution.execution_id)
await api.getGraphExecutionInfo(flow.agent_id, execution.id)
).node_executions;
// Create a map of the latest COMPLETED execution results of output nodes by node_id
@ -69,7 +69,7 @@ export const FlowRunInfo: React.FC<
result: result.output_data?.output || undefined,
})),
);
}, [api, flow.agent_id, execution.execution_id]);
}, [api, flow.agent_id, execution.id]);
// Fetch graph and execution data
useEffect(() => {
@ -84,8 +84,8 @@ export const FlowRunInfo: React.FC<
}
const handleStopRun = useCallback(() => {
api.stopGraphExecution(flow.agent_id, execution.execution_id);
}, [api, flow.agent_id, execution.execution_id]);
api.stopGraphExecution(flow.agent_id, execution.id);
}, [api, flow.agent_id, execution.id]);
return (
<>
@ -109,7 +109,7 @@ export const FlowRunInfo: React.FC<
{flow.can_access_graph && (
<Link
className={buttonVariants({ variant: "default" })}
href={`/build?flowID=${execution.graph_id}&flowVersion=${execution.graph_version}&flowExecutionID=${execution.execution_id}`}
href={`/build?flowID=${execution.graph_id}&flowVersion=${execution.graph_version}&flowExecutionID=${execution.id}`}
>
<Pencil2Icon className="mr-2" /> Open in Builder
</Link>
@ -121,7 +121,7 @@ export const FlowRunInfo: React.FC<
<strong>Agent ID:</strong> <code>{flow.agent_id}</code>
</p>
<p className="hidden">
<strong>Run ID:</strong> <code>{execution.execution_id}</code>
<strong>Run ID:</strong> <code>{execution.id}</code>
</p>
<div>
<strong>Status:</strong>{" "}

View File

@ -37,17 +37,13 @@ export const FlowRunsList: React.FC<{
<TableBody data-testid="flow-runs-list-body">
{executions.map((execution) => (
<TableRow
key={execution.execution_id}
data-testid={`flow-run-${execution.execution_id}-graph-${execution.graph_id}`}
data-runid={execution.execution_id}
key={execution.id}
data-testid={`flow-run-${execution.id}-graph-${execution.graph_id}`}
data-runid={execution.id}
data-graphid={execution.graph_id}
className="cursor-pointer"
onClick={() => onSelectRun(execution)}
data-state={
selectedRun?.execution_id == execution.execution_id
? "selected"
: null
}
data-state={selectedRun?.id == execution.id ? "selected" : null}
>
<TableCell>
<TextRenderer

View File

@ -29,7 +29,7 @@ export const FlowRunsStatus: React.FC<{
: statsSince;
const filteredFlowRuns =
statsSinceTimestamp != null
? executions.filter((fr) => fr.started_at > statsSinceTimestamp)
? executions.filter((fr) => Number(fr.started_at) > statsSinceTimestamp)
: executions;
return (

View File

@ -99,7 +99,7 @@ export const FlowRunsTimeline = ({
.filter((e) => e.graph_id == flow.agent_id)
.map((e) => ({
...e,
time: e.started_at + e.total_run_time * 1000,
time: Number(e.started_at) + e.total_run_time * 1000,
_duration: e.total_run_time,
}))}
name={flow.name}
@ -108,14 +108,14 @@ export const FlowRunsTimeline = ({
))}
{executions.map((execution) => (
<Line
key={execution.execution_id}
key={execution.id}
type="linear"
dataKey="_duration"
data={[
{ ...execution, time: execution.started_at, _duration: 0 },
{ ...execution, time: Number(execution.started_at), _duration: 0 },
{
...execution,
time: execution.ended_at,
time: Number(execution.ended_at),
_duration: execution.total_run_time,
},
]}

View File

@ -106,24 +106,24 @@ export default function useAgentGraph(
// Subscribe to execution events
useEffect(() => {
api.onWebSocketMessage("execution_event", (data) => {
api.onWebSocketMessage("node_execution_event", (data) => {
if (data.graph_exec_id != flowExecutionID) {
return;
}
setUpdateQueue((prev) => [...prev, data]);
});
if (flowID && flowVersion) {
if (flowExecutionID) {
api
.subscribeToExecution(flowID, flowVersion)
.subscribeToGraphExecution(flowExecutionID)
.then(() =>
console.debug(
`Subscribed to execution events for ${flowID} v.${flowVersion}`,
`Subscribed to updates for execution #${flowExecutionID}`,
),
)
.catch((error) =>
console.error(
`Failed to subscribe to execution events for ${flowID} v.${flowVersion}:`,
`Failed to subscribe to updates for execution #${flowExecutionID}:`,
error,
),
);
@ -235,7 +235,7 @@ export default function useAgentGraph(
return newNodes;
});
},
[availableNodes, availableFlows, formatEdgeID, getOutputType],
[availableNodes, availableFlows, getOutputType],
);
const getFrontendId = useCallback(
@ -636,7 +636,7 @@ export default function useAgentGraph(
// Track execution until completed
const pendingNodeExecutions: Set<string> = new Set();
const cancelExecListener = api.onWebSocketMessage(
"execution_event",
"node_execution_event",
(nodeResult) => {
// We are racing the server here, since we need the ID to filter events
if (nodeResult.graph_exec_id != flowExecutionID) {

View File

@ -106,8 +106,9 @@ export default class BackendAPI {
}
////////////////////////////////////////
///////////// CREDITS //////////////////
/////////////// CREDITS ////////////////
////////////////////////////////////////
getUserCredit(): Promise<{ credits: number }> {
try {
return this._get("/credits");
@ -172,8 +173,9 @@ export default class BackendAPI {
}
////////////////////////////////////////
/////////// ONBOARDING /////////////////
////////////// ONBOARDING //////////////
////////////////////////////////////////
getUserOnboarding(): Promise<UserOnboarding> {
return this._get("/onboarding");
}
@ -192,8 +194,9 @@ export default class BackendAPI {
}
////////////////////////////////////////
/////////// GRAPHS /////////////////////
//////////////// GRAPHS ////////////////
////////////////////////////////////////
getBlocks(): Promise<Block[]> {
return this._get("/blocks");
}
@ -398,9 +401,9 @@ export default class BackendAPI {
return this._request("POST", "/analytics/log_raw_analytics", analytic);
}
///////////////////////////////////////////
/////////// V2 STORE API /////////////////
/////////////////////////////////////////
////////////////////////////////////////
///////////// V2 STORE API /////////////
////////////////////////////////////////
getStoreProfile(): Promise<ProfileDetails | null> {
try {
@ -529,9 +532,9 @@ export default class BackendAPI {
return this._get(url);
}
/////////////////////////////////////////
/////////// Admin API ///////////////////
/////////////////////////////////////////
////////////////////////////////////////
////////////// Admin API ///////////////
////////////////////////////////////////
getAdminListingsWithVersions(params?: {
status?: SubmissionStatus;
@ -553,9 +556,9 @@ export default class BackendAPI {
);
}
/////////////////////////////////////////
/////////// V2 LIBRARY API //////////////
/////////////////////////////////////////
////////////////////////////////////////
//////////// V2 LIBRARY API ////////////
////////////////////////////////////////
listLibraryAgents(params?: {
search_term?: string;
@ -651,9 +654,9 @@ export default class BackendAPI {
);
}
///////////////////////////////////////////
/////////// INTERNAL FUNCTIONS ////////////
//////////////////////////////??///////////
////////////////////////////////////////
////////// INTERNAL FUNCTIONS //////////
////////////////////////////////////////
private _get(path: string, query?: Record<string, any>) {
return this._request("GET", path, query);
@ -815,104 +818,16 @@ export default class BackendAPI {
}
}
startHeartbeat() {
this.stopHeartbeat();
this.heartbeatInterval = window.setInterval(() => {
if (this.webSocket?.readyState === WebSocket.OPEN) {
this.webSocket.send(
JSON.stringify({
method: "heartbeat",
data: "ping",
success: true,
}),
);
////////////////////////////////////////
////////////// WEBSOCKETS //////////////
////////////////////////////////////////
this.heartbeatTimeoutId = window.setTimeout(() => {
console.warn("Heartbeat timeout - reconnecting");
this.webSocket?.close();
this.connectWebSocket();
}, this.HEARTBEAT_TIMEOUT);
}
}, this.HEARTBEAT_INTERVAL);
}
stopHeartbeat() {
if (this.heartbeatInterval) {
clearInterval(this.heartbeatInterval);
this.heartbeatInterval = null;
}
if (this.heartbeatTimeoutId) {
clearTimeout(this.heartbeatTimeoutId);
this.heartbeatTimeoutId = null;
}
}
handleHeartbeatResponse() {
if (this.heartbeatTimeoutId) {
clearTimeout(this.heartbeatTimeoutId);
this.heartbeatTimeoutId = null;
}
}
async connectWebSocket(): Promise<void> {
this.wsConnecting ??= new Promise(async (resolve, reject) => {
try {
const token =
(await this.supabaseClient?.auth.getSession())?.data.session
?.access_token || "";
const wsUrlWithToken = `${this.wsUrl}?token=${token}`;
this.webSocket = new WebSocket(wsUrlWithToken);
this.webSocket.onopen = () => {
this.startHeartbeat(); // Start heartbeat when connection opens
resolve();
};
this.webSocket.onclose = (event) => {
console.warn("WebSocket connection closed", event);
this.stopHeartbeat(); // Stop heartbeat when connection closes
this.wsConnecting = null;
// Attempt to reconnect after a delay
setTimeout(() => this.connectWebSocket(), 1000);
};
this.webSocket.onerror = (error) => {
console.error("WebSocket error:", error);
this.stopHeartbeat(); // Stop heartbeat on error
this.wsConnecting = null;
reject(error);
};
this.webSocket.onmessage = (event) => {
const message: WebsocketMessage = JSON.parse(event.data);
// Handle heartbeat response
if (message.method === "heartbeat" && message.data === "pong") {
this.handleHeartbeatResponse();
return;
}
if (message.method === "execution_event") {
message.data = parseNodeExecutionResultTimestamps(message.data);
}
this.wsMessageHandlers[message.method]?.forEach((handler) =>
handler(message.data),
);
};
} catch (error) {
console.error("Error connecting to WebSocket:", error);
reject(error);
}
subscribeToGraphExecution(graphExecID: GraphExecutionID): Promise<void> {
return this.sendWebSocketMessage("subscribe_graph_execution", {
graph_exec_id: graphExecID,
});
return this.wsConnecting;
}
disconnectWebSocket() {
this.stopHeartbeat(); // Stop heartbeat when disconnecting
if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) {
this.webSocket.close();
}
}
async sendWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
method: M,
data: WebsocketMessageTypeMap[M],
@ -920,7 +835,7 @@ export default class BackendAPI {
callCountLimit = 4,
): Promise<void> {
if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) {
const result = this.webSocket.send(JSON.stringify({ method, data }));
this.webSocket.send(JSON.stringify({ method, data }));
return;
}
if (callCount >= callCountLimit) {
@ -948,11 +863,105 @@ export default class BackendAPI {
return () => this.wsMessageHandlers[method].delete(handler);
}
async subscribeToExecution(graphId: string, graphVersion: number) {
await this.sendWebSocketMessage("subscribe", {
graph_id: graphId,
graph_version: graphVersion,
/**
 * Opens the WebSocket connection, or returns the connection attempt that is
 * already cached in `wsConnecting`.
 *
 * The returned promise resolves once the socket's `onopen` fires and rejects
 * if the socket reports an error or if setting up the connection throws.
 * After a close, a reconnect is scheduled 1000 ms later.
 *
 * Incoming messages are dispatched to the handlers registered for their
 * `method`; heartbeat "pong" replies are consumed here and not dispatched.
 */
async connectWebSocket(): Promise<void> {
  // `??=` makes concurrent callers share a single connection attempt.
  this.wsConnecting ??= new Promise(async (resolve, reject) => {
    try {
      // Fall back to an empty token when there is no client or session.
      const token =
        (await this.supabaseClient?.auth.getSession())?.data.session
          ?.access_token || "";
      const wsUrlWithToken = `${this.wsUrl}?token=${token}`;
      this.webSocket = new WebSocket(wsUrlWithToken);

      this.webSocket.onopen = () => {
        this._startWSHeartbeat(); // Start heartbeat when connection opens
        resolve();
      };

      this.webSocket.onclose = (event) => {
        console.warn("WebSocket connection closed", event);
        this._stopWSHeartbeat(); // Stop heartbeat when connection closes
        this.wsConnecting = null;
        // Attempt to reconnect after a delay
        setTimeout(() => this.connectWebSocket(), 1000);
      };

      this.webSocket.onerror = (error) => {
        console.error("WebSocket error:", error);
        this._stopWSHeartbeat(); // Stop heartbeat on error
        this.wsConnecting = null;
        reject(error);
      };

      this.webSocket.onmessage = (event) => {
        const message: WebsocketMessage = JSON.parse(event.data);

        // Handle heartbeat response
        if (message.method === "heartbeat" && message.data === "pong") {
          this._handleWSHeartbeatResponse();
          return;
        }

        // Convert serialized timestamps to Date objects before dispatching.
        if (message.method === "node_execution_event") {
          message.data = parseNodeExecutionResultTimestamps(message.data);
        } else if (message.method === "graph_execution_event") {
          message.data = parseGraphExecutionMetaTimestamps(message.data);
        }
        this.wsMessageHandlers[message.method]?.forEach((handler) =>
          handler(message.data),
        );
      };
    } catch (error) {
      console.error("Error connecting to WebSocket:", error);
      // Drop the cached attempt (as `onerror` does) so that a later call can
      // retry instead of receiving this same rejected promise again.
      this.wsConnecting = null;
      reject(error);
    }
  });
  return this.wsConnecting;
}
/**
 * Stops the heartbeat timers and closes the socket if it is currently open.
 * A socket in any other state (or no socket at all) is left untouched.
 */
disconnectWebSocket() {
  // Stop heartbeat when disconnecting
  this._stopWSHeartbeat();

  const socketIsOpen = this.webSocket?.readyState === WebSocket.OPEN;
  if (socketIsOpen) {
    this.webSocket?.close();
  }
}
/**
 * (Re)starts the heartbeat: any existing timers are cleared, then every
 * `HEARTBEAT_INTERVAL` a "ping" is sent over the socket if it is open.
 *
 * Each ping arms a timeout of `HEARTBEAT_TIMEOUT`; if that timeout fires, the
 * socket is closed and `connectWebSocket()` is called. The timeout is cleared
 * by `_handleWSHeartbeatResponse()` / `_stopWSHeartbeat()`.
 */
_startWSHeartbeat() {
  this._stopWSHeartbeat();

  const onHeartbeatTimeout = () => {
    console.warn("Heartbeat timeout - reconnecting");
    this.webSocket?.close();
    this.connectWebSocket();
  };

  const sendPing = () => {
    const socket = this.webSocket;
    // Skip this tick entirely unless the socket is open.
    if (socket?.readyState !== WebSocket.OPEN) {
      return;
    }

    const pingMessage = { method: "heartbeat", data: "ping", success: true };
    socket.send(JSON.stringify(pingMessage));

    this.heartbeatTimeoutId = window.setTimeout(
      onHeartbeatTimeout,
      this.HEARTBEAT_TIMEOUT,
    );
  };

  this.heartbeatInterval = window.setInterval(
    sendPing,
    this.HEARTBEAT_INTERVAL,
  );
}
/**
 * Cancels the periodic heartbeat interval and any pending heartbeat timeout,
 * resetting both stored timer IDs to `null`.
 */
_stopWSHeartbeat() {
  const intervalId = this.heartbeatInterval;
  if (intervalId) {
    clearInterval(intervalId);
    this.heartbeatInterval = null;
  }

  const timeoutId = this.heartbeatTimeoutId;
  if (timeoutId) {
    clearTimeout(timeoutId);
    this.heartbeatTimeoutId = null;
  }
}
/**
 * Cancels the pending heartbeat timeout, if any, and resets its stored ID to
 * `null`. Does nothing when no timeout is pending.
 */
_handleWSHeartbeatResponse() {
  if (!this.heartbeatTimeoutId) {
    return;
  }
  clearTimeout(this.heartbeatTimeoutId);
  this.heartbeatTimeoutId = null;
}
}
@ -963,8 +972,9 @@ type GraphCreateRequestBody = {
};
type WebsocketMessageTypeMap = {
subscribe: { graph_id: string; graph_version: number };
execution_event: NodeExecutionResult;
subscribe_graph_execution: { graph_exec_id: GraphExecutionID };
graph_execution_event: GraphExecutionMeta;
node_execution_event: NodeExecutionResult;
heartbeat: "ping" | "pong";
};
@ -984,19 +994,30 @@ type _PydanticValidationError = {
/* *** HELPER FUNCTIONS *** */
/**
 * Returns a copy of `result` whose `started_at` and `ended_at` fields have
 * been converted by `_parseObjectTimestamps`.
 */
function parseGraphExecutionMetaTimestamps(result: any): GraphExecutionMeta {
  const timestampKeys: (keyof GraphExecutionMeta)[] = [
    "started_at",
    "ended_at",
  ];
  return _parseObjectTimestamps<GraphExecutionMeta>(result, timestampKeys);
}
function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
return {
...result,
add_time: new Date(result.add_time),
queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
start_time: result.start_time ? new Date(result.start_time) : undefined,
end_time: result.end_time ? new Date(result.end_time) : undefined,
};
return _parseObjectTimestamps<NodeExecutionResult>(result, [
"add_time",
"queue_time",
"start_time",
"end_time",
]);
}
function parseScheduleTimestamp(result: any): Schedule {
return {
...result,
next_run_time: new Date(result.next_run_time),
};
return _parseObjectTimestamps<Schedule>(result, ["next_run_time"]);
}
/**
 * Returns a shallow copy of `obj` in which every property named in `keys` is
 * replaced by `new Date(value)` when the value is truthy, and by `undefined`
 * otherwise (including when the property is absent). The input object is not
 * modified; properties not listed in `keys` are copied as-is.
 *
 * @param obj - Source object holding serialized timestamps.
 * @param keys - Names of the properties to convert.
 * @returns The converted copy, typed as `T`.
 */
function _parseObjectTimestamps<T>(obj: any, keys: (keyof T)[]): T {
  const parsed = { ...obj };
  keys.forEach((field) => {
    const raw = parsed[field];
    if (raw) {
      parsed[field] = new Date(raw);
    } else {
      parsed[field] = undefined;
    }
  });
  return parsed;
}

View File

@ -230,9 +230,9 @@ export type LinkCreatable = Omit<Link, "id" | "is_static"> & {
/* Mirror of backend/data/execution.py:GraphExecutionMeta */
export type GraphExecutionMeta = {
execution_id: GraphExecutionID;
started_at: number;
ended_at: number;
id: GraphExecutionID;
started_at: Date;
ended_at: Date;
cost?: number;
duration: number;
total_run_time: number;