mirror of
https://github.com/Significant-Gravitas/Auto-GPT.git
synced 2025-04-03 23:41:45 +08:00
refactor(backend): Reorganize & clean up execution update system (#9663)
- Prep work for #8782 - Prep work for #8779 ### Changes 🏗️ - refactor(platform): Differentiate graph/node execution events - fix(platform): Subscribe to execution updates by `graph_exec_id` instead of `graph_id`+`graph_version` - refactor(backend): Move all execution related models and functions from `.data.graph` to `.data.execution` - refactor(backend): Reorganize & refactor `.data.execution` - fix(libs): Remove `load_dotenv` in `.auth.config` to fix test config issues - dx: Bump version of `black` in pre-commit config to v24.10.0 to match poetry.lock - Other minor refactoring in both frontend and backend ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - Run an agent in the builder - [x] -> works normally, node I/O is updated in real time - Run an agent in the library - [x] -> works normally --------- Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
This commit is contained in:
parent
37f212e950
commit
1162ec1474
@ -140,7 +140,7 @@ repos:
|
||||
language: system
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.12.1
|
||||
rev: 24.10.0
|
||||
# Black has sensible defaults, doesn't need package context, and ignores
|
||||
# everything in .gitignore, so it works fine without any config or arguments.
|
||||
hooks:
|
||||
|
@ -1,11 +1,9 @@
|
||||
from .config import Settings
|
||||
from .depends import requires_admin_user, requires_user
|
||||
from .jwt_utils import parse_jwt_token
|
||||
from .middleware import APIKeyValidator, auth_middleware
|
||||
from .models import User
|
||||
|
||||
__all__ = [
|
||||
"Settings",
|
||||
"parse_jwt_token",
|
||||
"requires_user",
|
||||
"requires_admin_user",
|
||||
|
@ -1,14 +1,11 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
def __init__(self):
|
||||
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
self.JWT_ALGORITHM: str = "HS256"
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import fastapi
|
||||
|
||||
from .config import Settings
|
||||
from .config import settings
|
||||
from .middleware import auth_middleware
|
||||
from .models import DEFAULT_USER_ID, User
|
||||
|
||||
@ -17,7 +17,7 @@ def requires_admin_user(
|
||||
|
||||
def verify_user(payload: dict | None, admin_only: bool) -> User:
|
||||
if not payload:
|
||||
if Settings.ENABLE_AUTH:
|
||||
if settings.ENABLE_AUTH:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="Authorization header is missing"
|
||||
)
|
||||
|
2
autogpt_platform/autogpt_libs/poetry.lock
generated
2
autogpt_platform/autogpt_libs/poetry.lock
generated
@ -1929,4 +1929,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "931772287f71c539575d601e6398423bf68e09ca87ae1a144057c7f5707cf978"
|
||||
content-hash = "02023e8698c80648fec23a112ec2ec90d617bba83081d194fab90f682908f0f3"
|
||||
|
@ -16,7 +16,6 @@ pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^0.25.3"
|
||||
pytest-mock = "^3.14.0"
|
||||
python = ">=3.10,<4.0"
|
||||
python-dotenv = "^1.0.1"
|
||||
supabase = "^2.13.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
@ -75,6 +75,8 @@ class AgentExecutorBlock(Block):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
from backend.data.execution import ExecutionEventType
|
||||
|
||||
executor_manager = get_executor_manager_client()
|
||||
event_bus = get_event_bus()
|
||||
|
||||
@ -90,11 +92,7 @@ class AgentExecutorBlock(Block):
|
||||
for event in event_bus.listen(
|
||||
graph_id=graph_exec.graph_id, graph_exec_id=graph_exec.graph_exec_id
|
||||
):
|
||||
logger.info(
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
if not event.node_id:
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
if event.status in [
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
@ -105,6 +103,10 @@ class AgentExecutorBlock(Block):
|
||||
else:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
if not event.block_id:
|
||||
logger.warning(f"{log_id} received event without block_id {event}")
|
||||
continue
|
||||
|
@ -220,9 +220,8 @@ def event():
|
||||
|
||||
@test.command()
|
||||
@click.argument("server_address")
|
||||
@click.argument("graph_id")
|
||||
@click.argument("graph_version")
|
||||
def websocket(server_address: str, graph_id: str, graph_version: int):
|
||||
@click.argument("graph_exec_id")
|
||||
def websocket(server_address: str, graph_exec_id: str):
|
||||
"""
|
||||
Tests the websocket connection.
|
||||
"""
|
||||
@ -230,16 +229,20 @@ def websocket(server_address: str, graph_id: str, graph_version: int):
|
||||
|
||||
import websockets.asyncio.client
|
||||
|
||||
from backend.server.ws_api import ExecutionSubscription, Methods, WsMessage
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
)
|
||||
|
||||
async def send_message(server_address: str):
|
||||
uri = f"ws://{server_address}"
|
||||
async with websockets.asyncio.client.connect(uri) as websocket:
|
||||
try:
|
||||
msg = WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
data=ExecutionSubscription(
|
||||
graph_id=graph_id, graph_version=graph_version
|
||||
msg = WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
data=WSSubscribeGraphExecutionRequest(
|
||||
graph_exec_id=graph_exec_id,
|
||||
).model_dump(),
|
||||
).model_dump_json()
|
||||
await websocket.send(msg)
|
||||
|
@ -1,7 +1,18 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, Optional, Type, TypeVar
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Generator,
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
@ -10,62 +21,148 @@ from prisma.models import (
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
)
|
||||
from prisma.types import AgentNodeExecutionUpdateInput, AgentNodeExecutionWhereInput
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
|
||||
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.queue import AsyncRedisEventBus, RedisEventBus
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import mock, type
|
||||
from backend.util import mock
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
|
||||
from .db import BaseDbModel
|
||||
from .includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
class GraphExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
start_node_execs: list["NodeExecutionEntry"]
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
node_exec_id: str
|
||||
node_id: str
|
||||
block_id: str
|
||||
data: BlockInput
|
||||
# -------------------------- Models -------------------------- #
|
||||
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
user_id: str
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
cost: Optional[int] = Field(..., description="Execution cost in credits")
|
||||
duration: float = Field(..., description="Seconds from start to end of run")
|
||||
total_run_time: float = Field(..., description="Seconds of node runtime")
|
||||
status: ExecutionStatus
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
preset_id: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
now = datetime.now(timezone.utc)
|
||||
start_time = _graph_exec.startedAt or _graph_exec.createdAt
|
||||
end_time = _graph_exec.updatedAt or now
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
total_run_time = duration
|
||||
|
||||
try:
|
||||
stats = GraphExecutionStats.model_validate(_graph_exec.stats)
|
||||
except ValueError as e:
|
||||
if _graph_exec.stats is not None:
|
||||
logger.warning(
|
||||
"Failed to parse invalid graph execution stats "
|
||||
f"{_graph_exec.stats}: {e}"
|
||||
)
|
||||
stats = None
|
||||
|
||||
duration = stats.walltime if stats else duration
|
||||
total_run_time = stats.nodes_walltime if stats else total_run_time
|
||||
|
||||
return GraphExecutionMeta(
|
||||
id=_graph_exec.id,
|
||||
user_id=_graph_exec.userId,
|
||||
started_at=start_time,
|
||||
ended_at=end_time,
|
||||
cost=stats.cost if stats else None,
|
||||
duration=duration,
|
||||
total_run_time=total_run_time,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
graph_id=_graph_exec.agentGraphId,
|
||||
graph_version=_graph_exec.agentGraphVersion,
|
||||
preset_id=_graph_exec.agentPresetId,
|
||||
)
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
"""
|
||||
Queue for managing the execution of agents.
|
||||
This will be shared between different processes
|
||||
"""
|
||||
class GraphExecution(GraphExecutionMeta):
|
||||
inputs: BlockInput
|
||||
outputs: CompletedBlockOutput
|
||||
node_executions: list["NodeExecutionResult"]
|
||||
|
||||
def __init__(self):
|
||||
self.queue = Manager().Queue()
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
|
||||
def add(self, execution: T) -> T:
|
||||
self.queue.put(execution)
|
||||
return execution
|
||||
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
|
||||
|
||||
def get(self) -> T:
|
||||
return self.queue.get()
|
||||
node_executions = sorted(
|
||||
[
|
||||
NodeExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
|
||||
def empty(self) -> bool:
|
||||
return self.queue.empty()
|
||||
inputs = {
|
||||
**{
|
||||
# inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
},
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
for exec in node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type
|
||||
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
outputs: CompletedBlockOutput = defaultdict(list)
|
||||
for exec in node_executions:
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
outputs[exec.input_data["name"]].append(
|
||||
exec.input_data.get("value", None)
|
||||
)
|
||||
|
||||
return GraphExecution(
|
||||
**{
|
||||
field_name: getattr(graph_exec, field_name)
|
||||
for field_name in graph_exec.model_fields
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
|
||||
class ExecutionResult(BaseModel):
|
||||
class NodeExecutionResult(BaseModel):
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
@ -81,41 +178,20 @@ class ExecutionResult(BaseModel):
|
||||
start_time: datetime | None
|
||||
end_time: datetime | None
|
||||
|
||||
@staticmethod
|
||||
def from_graph(graph_exec: AgentGraphExecution):
|
||||
return ExecutionResult(
|
||||
user_id=graph_exec.userId,
|
||||
graph_id=graph_exec.agentGraphId,
|
||||
graph_version=graph_exec.agentGraphVersion,
|
||||
graph_exec_id=graph_exec.id,
|
||||
node_exec_id="",
|
||||
node_id="",
|
||||
block_id="",
|
||||
status=graph_exec.executionStatus,
|
||||
# TODO: Populate input_data & output_data from AgentNodeExecutions
|
||||
# Input & Output comes AgentInputBlock & AgentOutputBlock.
|
||||
input_data={},
|
||||
output_data={},
|
||||
add_time=graph_exec.createdAt,
|
||||
queue_time=graph_exec.createdAt,
|
||||
start_time=graph_exec.startedAt,
|
||||
end_time=graph_exec.updatedAt,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
if execution.executionData:
|
||||
# Execution that has been queued for execution will persist its data.
|
||||
input_data = type.convert(execution.executionData, dict[str, Any])
|
||||
input_data = type_utils.convert(execution.executionData, dict[str, Any])
|
||||
else:
|
||||
# For incomplete execution, executionData will not be yet available.
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in execution.Input or []:
|
||||
input_data[data.name] = type.convert(data.data, Type[Any])
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
for data in execution.Output or []:
|
||||
output_data[data.name].append(type.convert(data.data, Type[Any]))
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
|
||||
if graph_execution:
|
||||
@ -125,7 +201,7 @@ class ExecutionResult(BaseModel):
|
||||
"AgentGraphExecution must be included or user_id passed in"
|
||||
)
|
||||
|
||||
return ExecutionResult(
|
||||
return NodeExecutionResult(
|
||||
user_id=user_id,
|
||||
graph_id=graph_execution.agentGraphId if graph_execution else "",
|
||||
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
|
||||
@ -146,13 +222,49 @@ class ExecutionResult(BaseModel):
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
if user_id:
|
||||
where_filter["userId"] = user_id
|
||||
if graph_id:
|
||||
where_filter["agentGraphId"] = graph_id
|
||||
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where=where_filter,
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [GraphExecutionMeta.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
async def get_graph_execution_meta(
|
||||
user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id}
|
||||
)
|
||||
return GraphExecutionMeta.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def get_graph_execution(user_id: str, execution_id: str) -> GraphExecution | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return GraphExecution.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
preset_id: str | None = None,
|
||||
) -> tuple[str, list[ExecutionResult]]:
|
||||
) -> tuple[str, list[NodeExecutionResult]]:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
Returns:
|
||||
@ -186,7 +298,7 @@ async def create_graph_execution(
|
||||
)
|
||||
|
||||
return result.id, [
|
||||
ExecutionResult.from_db(execution, result.userId)
|
||||
NodeExecutionResult.from_db(execution, result.userId)
|
||||
for execution in result.AgentNodeExecutions or []
|
||||
]
|
||||
|
||||
@ -236,7 +348,7 @@ async def upsert_execution_input(
|
||||
)
|
||||
return existing_execution.id, {
|
||||
**{
|
||||
input_data.name: type.convert(input_data.data, Type[Any])
|
||||
input_data.name: type_utils.convert(input_data.data, type[Any])
|
||||
for input_data in existing_execution.Input or []
|
||||
},
|
||||
input_name: input_data,
|
||||
@ -276,7 +388,7 @@ async def upsert_execution_output(
|
||||
)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResult:
|
||||
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecutionMeta:
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
data={
|
||||
@ -285,16 +397,16 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
|
||||
},
|
||||
)
|
||||
if not res:
|
||||
raise ValueError(f"Execution {graph_exec_id} not found.")
|
||||
raise ValueError(f"Graph execution #{graph_exec_id} not found")
|
||||
|
||||
return ExecutionResult.from_graph(res)
|
||||
return GraphExecutionMeta.from_db(res)
|
||||
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> ExecutionResult:
|
||||
) -> GraphExecutionMeta:
|
||||
data = stats.model_dump() if stats else {}
|
||||
if isinstance(data.get("error"), Exception):
|
||||
data["error"] = str(data["error"])
|
||||
@ -312,9 +424,9 @@ async def update_graph_execution_stats(
|
||||
},
|
||||
)
|
||||
if not res:
|
||||
raise ValueError(f"Execution {graph_exec_id} not found.")
|
||||
raise ValueError(f"Graph execution #{graph_exec_id} not found")
|
||||
|
||||
return ExecutionResult.from_graph(res)
|
||||
return GraphExecutionMeta.from_db(res)
|
||||
|
||||
|
||||
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
|
||||
@ -327,7 +439,7 @@ async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionSta
|
||||
)
|
||||
|
||||
|
||||
async def update_execution_status_batch(
|
||||
async def update_node_execution_status_batch(
|
||||
node_exec_ids: list[str],
|
||||
status: ExecutionStatus,
|
||||
stats: dict[str, Any] | None = None,
|
||||
@ -338,12 +450,12 @@ async def update_execution_status_batch(
|
||||
)
|
||||
|
||||
|
||||
async def update_execution_status(
|
||||
async def update_node_execution_status(
|
||||
node_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> ExecutionResult:
|
||||
) -> NodeExecutionResult:
|
||||
if status == ExecutionStatus.QUEUED and execution_data is None:
|
||||
raise ValueError("Execution data must be provided when queuing an execution.")
|
||||
|
||||
@ -355,7 +467,7 @@ async def update_execution_status(
|
||||
if not res:
|
||||
raise ValueError(f"Execution {node_exec_id} not found.")
|
||||
|
||||
return ExecutionResult.from_db(res)
|
||||
return NodeExecutionResult.from_db(res)
|
||||
|
||||
|
||||
def _get_update_status_data(
|
||||
@ -381,7 +493,7 @@ def _get_update_status_data(
|
||||
return update_data
|
||||
|
||||
|
||||
async def delete_execution(
|
||||
async def delete_graph_execution(
|
||||
graph_exec_id: str, user_id: str, soft_delete: bool = True
|
||||
) -> None:
|
||||
if soft_delete:
|
||||
@ -398,12 +510,12 @@ async def delete_execution(
|
||||
)
|
||||
|
||||
|
||||
async def get_execution_results(
|
||||
async def get_node_execution_results(
|
||||
graph_exec_id: str,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[ExecutionResult]:
|
||||
) -> list[NodeExecutionResult]:
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
@ -417,10 +529,110 @@ async def get_execution_results(
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
take=limit,
|
||||
)
|
||||
res = [ExecutionResult.from_db(execution) for execution in executions]
|
||||
res = [NodeExecutionResult.from_db(execution) for execution in executions]
|
||||
return res
|
||||
|
||||
|
||||
async def get_graph_executions_in_timerange(
|
||||
user_id: str, start_time: str, end_time: str
|
||||
) -> list[GraphExecution]:
|
||||
try:
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where={
|
||||
"startedAt": {
|
||||
"gte": datetime.fromisoformat(start_time),
|
||||
"lte": datetime.fromisoformat(end_time),
|
||||
},
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return [GraphExecution.from_db(execution) for execution in executions]
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_latest_node_execution(
|
||||
node_id: str, graph_eid: str
|
||||
) -> NodeExecutionResult | None:
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
{"addedTime": "desc"},
|
||||
],
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not execution:
|
||||
return None
|
||||
return NodeExecutionResult.from_db(execution)
|
||||
|
||||
|
||||
async def get_incomplete_node_executions(
|
||||
node_id: str, graph_eid: str
|
||||
) -> list[NodeExecutionResult]:
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
return [NodeExecutionResult.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
# ----------------- Execution Infrastructure ----------------- #
|
||||
|
||||
|
||||
class GraphExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
start_node_execs: list["NodeExecutionEntry"]
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
node_exec_id: str
|
||||
node_id: str
|
||||
block_id: str
|
||||
data: BlockInput
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
"""
|
||||
Queue for managing the execution of agents.
|
||||
This will be shared between different processes
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queue = Manager().Queue()
|
||||
|
||||
def add(self, execution: T) -> T:
|
||||
self.queue.put(execution)
|
||||
return execution
|
||||
|
||||
def get(self) -> T:
|
||||
return self.queue.get()
|
||||
|
||||
def empty(self) -> bool:
|
||||
return self.queue.empty()
|
||||
|
||||
|
||||
# ------------------- Execution Utilities -------------------- #
|
||||
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
@ -556,72 +768,82 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
return data
|
||||
|
||||
|
||||
async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult | None:
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
{"addedTime": "desc"},
|
||||
],
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not execution:
|
||||
return None
|
||||
return ExecutionResult.from_db(execution)
|
||||
|
||||
|
||||
async def get_incomplete_executions(
|
||||
node_id: str, graph_eid: str
|
||||
) -> list[ExecutionResult]:
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
return [ExecutionResult.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
# --------------------- Event Bus --------------------- #
|
||||
|
||||
config = Config()
|
||||
|
||||
class ExecutionEventType(str, Enum):
|
||||
GRAPH_EXEC_UPDATE = "graph_execution_update"
|
||||
NODE_EXEC_UPDATE = "node_execution_update"
|
||||
|
||||
|
||||
class RedisExecutionEventBus(RedisEventBus[ExecutionResult]):
|
||||
Model = ExecutionResult
|
||||
class GraphExecutionEvent(GraphExecutionMeta):
|
||||
event_type: Literal[ExecutionEventType.GRAPH_EXEC_UPDATE] = (
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE
|
||||
)
|
||||
|
||||
|
||||
class NodeExecutionEvent(NodeExecutionResult):
|
||||
event_type: Literal[ExecutionEventType.NODE_EXEC_UPDATE] = (
|
||||
ExecutionEventType.NODE_EXEC_UPDATE
|
||||
)
|
||||
|
||||
|
||||
ExecutionEvent = Annotated[
|
||||
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
|
||||
]
|
||||
|
||||
|
||||
class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
|
||||
Model = ExecutionEvent # type: ignore
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
def publish(self, res: ExecutionResult):
|
||||
self.publish_event(res, f"{res.graph_id}/{res.graph_exec_id}")
|
||||
def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecutionMeta):
|
||||
self.publish_graph_exec_update(res)
|
||||
else:
|
||||
self.publish_node_exec_update(res)
|
||||
|
||||
def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
self.publish_event(event, f"{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
def publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
self.publish_event(event, f"{res.graph_id}/{res.id}")
|
||||
|
||||
def listen(
|
||||
self, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
) -> Generator[ExecutionResult, None, None]:
|
||||
for execution_result in self.listen_events(f"{graph_id}/{graph_exec_id}"):
|
||||
yield execution_result
|
||||
) -> Generator[ExecutionEvent, None, None]:
|
||||
for event in self.listen_events(f"{graph_id}/{graph_exec_id}"):
|
||||
yield event
|
||||
|
||||
|
||||
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionResult]):
|
||||
Model = ExecutionResult
|
||||
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
|
||||
Model = ExecutionEvent # type: ignore
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
async def publish(self, res: ExecutionResult):
|
||||
await self.publish_event(res, f"{res.graph_id}/{res.graph_exec_id}")
|
||||
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecutionMeta):
|
||||
await self.publish_graph_exec_update(res)
|
||||
else:
|
||||
await self.publish_node_exec_update(res)
|
||||
|
||||
async def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
await self.publish_event(event, f"{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
await self.publish_event(event, f"{res.graph_id}/{res.id}")
|
||||
|
||||
async def listen(
|
||||
self, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
) -> AsyncGenerator[ExecutionResult, None]:
|
||||
async for execution_result in self.listen_events(f"{graph_id}/{graph_exec_id}"):
|
||||
yield execution_result
|
||||
) -> AsyncGenerator[ExecutionEvent, None]:
|
||||
async for event in self.listen_events(f"{graph_id}/{graph_exec_id}"):
|
||||
yield event
|
||||
|
@ -1,21 +1,14 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
from prisma.enums import SubmissionStatus
|
||||
from prisma.models import (
|
||||
AgentGraph,
|
||||
AgentGraphExecution,
|
||||
AgentNode,
|
||||
AgentNodeLink,
|
||||
StoreListingVersion,
|
||||
)
|
||||
from prisma.types import AgentGraphExecutionWhereInput, AgentGraphWhereInput
|
||||
from pydantic.fields import Field, computed_field
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
|
||||
from prisma.types import AgentGraphWhereInput
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
@ -25,15 +18,11 @@ from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, transaction
|
||||
from .execution import ExecutionResult, ExecutionStatus
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_INPUT_BLOCK_ID = AgentInputBlock().id
|
||||
_OUTPUT_BLOCK_ID = AgentOutputBlock().id
|
||||
|
||||
|
||||
class Link(BaseDbModel):
|
||||
source_id: str
|
||||
@ -164,111 +153,6 @@ class NodeModel(Node):
|
||||
Webhook.model_rebuild()
|
||||
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
execution_id: str
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
cost: Optional[int] = Field(..., description="Execution cost in credits")
|
||||
duration: float
|
||||
total_run_time: float
|
||||
status: ExecutionStatus
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
preset_id: Optional[str]
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
now = datetime.now(timezone.utc)
|
||||
start_time = _graph_exec.startedAt or _graph_exec.createdAt
|
||||
end_time = _graph_exec.updatedAt or now
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
total_run_time = duration
|
||||
|
||||
try:
|
||||
stats = type_utils.convert(_graph_exec.stats or {}, dict[str, Any])
|
||||
except ValueError:
|
||||
stats = {}
|
||||
|
||||
duration = stats.get("walltime", duration)
|
||||
total_run_time = stats.get("nodes_walltime", total_run_time)
|
||||
|
||||
return GraphExecutionMeta(
|
||||
id=_graph_exec.id,
|
||||
execution_id=_graph_exec.id,
|
||||
started_at=start_time,
|
||||
ended_at=end_time,
|
||||
cost=stats.get("cost", None),
|
||||
duration=duration,
|
||||
total_run_time=total_run_time,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
graph_id=_graph_exec.agentGraphId,
|
||||
graph_version=_graph_exec.agentGraphVersion,
|
||||
preset_id=_graph_exec.agentPresetId,
|
||||
)
|
||||
|
||||
|
||||
class GraphExecution(GraphExecutionMeta):
|
||||
inputs: dict[str, Any]
|
||||
outputs: dict[str, list[Any]]
|
||||
node_executions: list[ExecutionResult]
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
|
||||
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
|
||||
|
||||
node_executions = sorted(
|
||||
[
|
||||
ExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
**{
|
||||
# inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
},
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
for exec in node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type
|
||||
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
outputs: dict[str, list] = defaultdict(list)
|
||||
for exec in node_executions:
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
outputs[exec.input_data["name"]].append(
|
||||
exec.input_data.get("value", None)
|
||||
)
|
||||
|
||||
return GraphExecution(
|
||||
**{
|
||||
field_name: getattr(graph_exec, field_name)
|
||||
for field_name in graph_exec.model_fields
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
|
||||
class BaseGraph(BaseDbModel):
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
@ -644,45 +528,6 @@ async def get_graphs(
|
||||
return graph_models
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
if user_id:
|
||||
where_filter["userId"] = user_id
|
||||
if graph_id:
|
||||
where_filter["agentGraphId"] = graph_id
|
||||
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where=where_filter,
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [GraphExecutionMeta.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
async def get_execution_meta(
|
||||
user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id}
|
||||
)
|
||||
return GraphExecutionMeta.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def get_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
) -> GraphExecution | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return GraphExecution.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
|
||||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
|
@ -411,6 +411,7 @@ class NodeExecutionStats(BaseModel):
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "allow"
|
||||
|
||||
error: Optional[Exception | str] = None
|
||||
walltime: float = 0
|
||||
@ -428,12 +429,19 @@ class GraphExecutionStats(BaseModel):
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "allow"
|
||||
|
||||
error: Optional[Exception | str] = None
|
||||
walltime: float = 0
|
||||
walltime: float = Field(
|
||||
default=0, description="Time between start and end of run (seconds)"
|
||||
)
|
||||
cputime: float = 0
|
||||
nodes_walltime: float = 0
|
||||
nodes_walltime: float = Field(
|
||||
default=0, description="Total node execution time (seconds)"
|
||||
)
|
||||
nodes_cputime: float = 0
|
||||
node_count: int = 0
|
||||
node_error_count: int = 0
|
||||
cost: float = 0
|
||||
node_count: int = Field(default=0, description="Total number of node executions")
|
||||
node_error_count: int = Field(
|
||||
default=0, description="Total number of errors generated"
|
||||
)
|
||||
cost: int = Field(default=0, description="Total execution cost (cents)")
|
||||
|
@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -14,13 +12,6 @@ from backend.data import redis
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, datetime):
|
||||
return o.isoformat()
|
||||
return super().default(o)
|
||||
|
||||
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
@ -32,8 +23,12 @@ class BaseRedisEventBus(Generic[M], ABC):
|
||||
def event_bus_name(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
def Message(self) -> type["_EventPayloadWrapper[M]"]:
|
||||
return _EventPayloadWrapper[self.Model]
|
||||
|
||||
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
|
||||
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
|
||||
message = self.Message(payload=item).model_dump_json()
|
||||
channel_name = f"{self.event_bus_name}/{channel_key}"
|
||||
logger.debug(f"[{channel_name}] Publishing an event to Redis {message}")
|
||||
return message, channel_name
|
||||
@ -43,9 +38,8 @@ class BaseRedisEventBus(Generic[M], ABC):
|
||||
if msg["type"] != message_type:
|
||||
return None
|
||||
try:
|
||||
data = json.loads(msg["data"])
|
||||
logger.debug(f"Consuming an event from Redis {data}")
|
||||
return self.Model(**data)
|
||||
logger.debug(f"[{channel_key}] Consuming an event from Redis {msg['data']}")
|
||||
return self.Message.model_validate_json(msg["data"]).payload
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse event result from Redis {msg} {e}")
|
||||
|
||||
@ -57,9 +51,16 @@ class BaseRedisEventBus(Generic[M], ABC):
|
||||
return pubsub, full_channel_name
|
||||
|
||||
|
||||
class RedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
Model: type[M]
|
||||
class _EventPayloadWrapper(BaseModel, Generic[M]):
|
||||
"""
|
||||
Wrapper model to allow `RedisEventBus.Model` to be a discriminated union
|
||||
of multiple event types.
|
||||
"""
|
||||
|
||||
payload: M
|
||||
|
||||
|
||||
class RedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
@property
|
||||
def connection(self) -> redis.Redis:
|
||||
return redis.get_redis()
|
||||
@ -85,8 +86,6 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
|
||||
|
||||
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
Model: type[M]
|
||||
|
||||
@property
|
||||
async def connection(self) -> redis.AsyncRedis:
|
||||
return await redis.get_redis_async()
|
||||
|
@ -1,16 +1,17 @@
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
ExecutionResult,
|
||||
GraphExecutionMeta,
|
||||
NodeExecutionResult,
|
||||
RedisExecutionEventBus,
|
||||
create_graph_execution,
|
||||
get_execution_results,
|
||||
get_incomplete_executions,
|
||||
get_latest_execution,
|
||||
update_execution_status,
|
||||
update_execution_status_batch,
|
||||
get_incomplete_node_executions,
|
||||
get_latest_node_execution,
|
||||
get_node_execution_results,
|
||||
update_graph_execution_start_time,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_stats,
|
||||
update_node_execution_status,
|
||||
update_node_execution_status_batch,
|
||||
upsert_execution_input,
|
||||
upsert_execution_output,
|
||||
)
|
||||
@ -55,23 +56,29 @@ class DatabaseManager(AppService):
|
||||
super().__init__()
|
||||
self.use_db = True
|
||||
self.use_redis = True
|
||||
self.event_queue = RedisExecutionEventBus()
|
||||
self.execution_event_bus = RedisExecutionEventBus()
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return config.database_api_port
|
||||
|
||||
@expose
|
||||
def send_execution_update(self, execution_result: ExecutionResult):
|
||||
self.event_queue.publish(execution_result)
|
||||
def send_execution_update(
|
||||
self, execution_result: GraphExecutionMeta | NodeExecutionResult
|
||||
):
|
||||
self.execution_event_bus.publish(execution_result)
|
||||
|
||||
# Executions
|
||||
create_graph_execution = exposed_run_and_wait(create_graph_execution)
|
||||
get_execution_results = exposed_run_and_wait(get_execution_results)
|
||||
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
|
||||
get_latest_execution = exposed_run_and_wait(get_latest_execution)
|
||||
update_execution_status = exposed_run_and_wait(update_execution_status)
|
||||
update_execution_status_batch = exposed_run_and_wait(update_execution_status_batch)
|
||||
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
|
||||
get_incomplete_node_executions = exposed_run_and_wait(
|
||||
get_incomplete_node_executions
|
||||
)
|
||||
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
|
||||
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
|
||||
update_node_execution_status_batch = exposed_run_and_wait(
|
||||
update_node_execution_status_batch
|
||||
)
|
||||
update_graph_execution_start_time = exposed_run_and_wait(
|
||||
update_graph_execution_start_time
|
||||
)
|
||||
|
@ -40,10 +40,10 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionResult,
|
||||
ExecutionStatus,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
)
|
||||
@ -149,8 +149,9 @@ def execute_node(
|
||||
node_exec_id = data.node_exec_id
|
||||
node_id = data.node_id
|
||||
|
||||
def update_execution(status: ExecutionStatus) -> ExecutionResult:
|
||||
exec_update = db_client.update_execution_status(node_exec_id, status)
|
||||
def update_execution(status: ExecutionStatus) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = db_client.update_node_execution_status(node_exec_id, status)
|
||||
db_client.send_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
@ -281,7 +282,7 @@ def _enqueue_next_nodes(
|
||||
def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
) -> NodeExecutionEntry:
|
||||
exec_update = db_client.update_execution_status(
|
||||
exec_update = db_client.update_node_execution_status(
|
||||
node_exec_id, ExecutionStatus.QUEUED, data
|
||||
)
|
||||
db_client.send_execution_update(exec_update)
|
||||
@ -326,7 +327,7 @@ def _enqueue_next_nodes(
|
||||
if link.is_static and link.sink_name not in next_node_input
|
||||
}
|
||||
if static_link_names and (
|
||||
latest_execution := db_client.get_latest_execution(
|
||||
latest_execution := db_client.get_latest_node_execution(
|
||||
next_node_id, graph_exec_id
|
||||
)
|
||||
):
|
||||
@ -359,7 +360,7 @@ def _enqueue_next_nodes(
|
||||
|
||||
# If link is static, there could be some incomplete executions waiting for it.
|
||||
# Load and complete the input missing input data, and try to re-enqueue them.
|
||||
for iexec in db_client.get_incomplete_executions(
|
||||
for iexec in db_client.get_incomplete_node_executions(
|
||||
next_node_id, graph_exec_id
|
||||
):
|
||||
idata = iexec.input_data
|
||||
@ -732,7 +733,6 @@ class Executor:
|
||||
running_executions: dict[str, AsyncResult] = {}
|
||||
|
||||
def make_exec_callback(exec_data: NodeExecutionEntry):
|
||||
|
||||
def callback(result: object):
|
||||
running_executions.pop(exec_data.node_id)
|
||||
|
||||
@ -778,7 +778,7 @@ class Executor:
|
||||
exec_id = exec_data.node_exec_id
|
||||
cls.db_client.upsert_execution_output(exec_id, "error", str(error))
|
||||
|
||||
exec_update = cls.db_client.update_execution_status(
|
||||
exec_update = cls.db_client.update_node_execution_status(
|
||||
exec_id, ExecutionStatus.FAILED
|
||||
)
|
||||
cls.db_client.send_execution_update(exec_update)
|
||||
@ -842,7 +842,7 @@ class Executor:
|
||||
metadata = cls.db_client.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = cls.db_client.get_execution_results(
|
||||
outputs = cls.db_client.get_node_execution_results(
|
||||
graph_exec.graph_exec_id,
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
@ -1061,7 +1061,7 @@ class ExecutionManager(AppService):
|
||||
graph_exec_id,
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
node_execs = self.db_client.get_execution_results(
|
||||
node_execs = self.db_client.get_node_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
@ -1069,7 +1069,7 @@ class ExecutionManager(AppService):
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
self.db_client.update_execution_status_batch(
|
||||
self.db_client.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
|
@ -2,8 +2,17 @@ from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.data import execution
|
||||
from backend.server.model import Methods, WsMessage
|
||||
from backend.data.execution import (
|
||||
ExecutionEventType,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.model import WSMessage, WSMethod
|
||||
|
||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
ExecutionEventType.NODE_EXEC_UPDATE: WSMethod.NODE_EXECUTION_EVENT,
|
||||
}
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
@ -11,39 +20,58 @@ class ConnectionManager:
|
||||
self.active_connections: Set[WebSocket] = set()
|
||||
self.subscriptions: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
async def connect_socket(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections.add(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
def disconnect_socket(self, websocket: WebSocket):
|
||||
self.active_connections.remove(websocket)
|
||||
for subscribers in self.subscriptions.values():
|
||||
subscribers.discard(websocket)
|
||||
|
||||
async def subscribe(
|
||||
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
|
||||
):
|
||||
key = f"{user_id}_{graph_id}_{graph_version}"
|
||||
async def subscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
) -> str:
|
||||
key = _graph_exec_channel_key(user_id, graph_exec_id)
|
||||
if key not in self.subscriptions:
|
||||
self.subscriptions[key] = set()
|
||||
self.subscriptions[key].add(websocket)
|
||||
return key
|
||||
|
||||
async def unsubscribe(
|
||||
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
|
||||
):
|
||||
key = f"{user_id}_{graph_id}_{graph_version}"
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
) -> str | None:
|
||||
key = _graph_exec_channel_key(user_id, graph_exec_id)
|
||||
if key in self.subscriptions:
|
||||
self.subscriptions[key].discard(websocket)
|
||||
if not self.subscriptions[key]:
|
||||
del self.subscriptions[key]
|
||||
return key
|
||||
return None
|
||||
|
||||
async def send_execution_result(self, result: execution.ExecutionResult):
|
||||
key = f"{result.user_id}_{result.graph_id}_{result.graph_version}"
|
||||
async def send_execution_update(
|
||||
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
|
||||
) -> int:
|
||||
graph_exec_id = (
|
||||
exec_event.id
|
||||
if isinstance(exec_event, GraphExecutionEvent)
|
||||
else exec_event.graph_exec_id
|
||||
)
|
||||
key = _graph_exec_channel_key(exec_event.user_id, graph_exec_id)
|
||||
|
||||
n_sent = 0
|
||||
if key in self.subscriptions:
|
||||
message = WsMessage(
|
||||
method=Methods.EXECUTION_EVENT,
|
||||
message = WSMessage(
|
||||
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
|
||||
channel=key,
|
||||
data=result.model_dump(),
|
||||
data=exec_event.model_dump(),
|
||||
).model_dump_json()
|
||||
for connection in self.subscriptions[key]:
|
||||
await connection.send_text(message)
|
||||
n_sent += 1
|
||||
|
||||
return n_sent
|
||||
|
||||
|
||||
def _graph_exec_channel_key(user_id: str, graph_exec_id: str) -> str:
|
||||
return f"{user_id}|graph_exec#{graph_exec_id}"
|
||||
|
@ -12,7 +12,7 @@ from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.execution import ExecutionResult
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
from backend.executor import ExecutionManager
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.service import get_service_client
|
||||
@ -53,7 +53,7 @@ class GraphExecutionResult(TypedDict):
|
||||
output: Optional[List[Dict[str, str]]]
|
||||
|
||||
|
||||
def get_outputs_with_names(results: List[ExecutionResult]) -> List[Dict[str, str]]:
|
||||
def get_outputs_with_names(results: list[NodeExecutionResult]) -> list[dict[str, str]]:
|
||||
outputs = []
|
||||
for result in results:
|
||||
if "output" in result.output_data:
|
||||
@ -130,7 +130,7 @@ async def get_graph_execution_results(
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
results = await execution_db.get_execution_results(graph_exec_id)
|
||||
results = await execution_db.get_node_execution_results(graph_exec_id)
|
||||
last_result = results[-1] if results else None
|
||||
execution_status = (
|
||||
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE
|
||||
|
@ -1,31 +1,31 @@
|
||||
import enum
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
|
||||
import backend.data.graph
|
||||
from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
|
||||
from backend.data.graph import Graph
|
||||
|
||||
|
||||
class Methods(enum.Enum):
|
||||
SUBSCRIBE = "subscribe"
|
||||
class WSMethod(enum.Enum):
|
||||
SUBSCRIBE_GRAPH_EXEC = "subscribe_graph_execution"
|
||||
UNSUBSCRIBE = "unsubscribe"
|
||||
EXECUTION_EVENT = "execution_event"
|
||||
GRAPH_EXECUTION_EVENT = "graph_execution_event"
|
||||
NODE_EXECUTION_EVENT = "node_execution_event"
|
||||
ERROR = "error"
|
||||
HEARTBEAT = "heartbeat"
|
||||
|
||||
|
||||
class WsMessage(pydantic.BaseModel):
|
||||
method: Methods
|
||||
data: Optional[Union[dict[str, Any], list[Any], str]] = None
|
||||
class WSMessage(pydantic.BaseModel):
|
||||
method: WSMethod
|
||||
data: Optional[dict[str, Any] | list[Any] | str] = None
|
||||
success: bool | None = None
|
||||
channel: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ExecutionSubscription(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
class WSSubscribeGraphExecutionRequest(pydantic.BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
class ExecuteGraphResponse(pydantic.BaseModel):
|
||||
@ -33,12 +33,12 @@ class ExecuteGraphResponse(pydantic.BaseModel):
|
||||
|
||||
|
||||
class CreateGraph(pydantic.BaseModel):
|
||||
graph: backend.data.graph.Graph
|
||||
graph: Graph
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
permissions: List[APIKeyPermission]
|
||||
permissions: list[APIKeyPermission]
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ class SetGraphActiveVersion(pydantic.BaseModel):
|
||||
|
||||
|
||||
class UpdatePermissionsRequest(pydantic.BaseModel):
|
||||
permissions: List[APIKeyPermission]
|
||||
permissions: list[APIKeyPermission]
|
||||
|
||||
|
||||
class Pagination(pydantic.BaseModel):
|
||||
|
@ -177,7 +177,9 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
@staticmethod
|
||||
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
|
||||
execution = await backend.data.graph.get_execution_meta(
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
|
||||
execution = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not execution:
|
||||
|
@ -599,8 +599,8 @@ def execute_graph(
|
||||
)
|
||||
async def stop_graph_run(
|
||||
graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.GraphExecution:
|
||||
if not await graph_db.get_execution_meta(
|
||||
) -> execution_db.GraphExecution:
|
||||
if not await execution_db.get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
):
|
||||
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
|
||||
@ -610,7 +610,9 @@ async def stop_graph_run(
|
||||
)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
|
||||
result = await execution_db.get_graph_execution(
|
||||
execution_id=graph_exec_id, user_id=user_id
|
||||
)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
500,
|
||||
@ -626,8 +628,8 @@ async def stop_graph_run(
|
||||
)
|
||||
async def get_graphs_executions(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[graph_db.GraphExecutionMeta]:
|
||||
return await graph_db.get_graph_executions(user_id=user_id)
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
return await execution_db.get_graph_executions(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@ -638,8 +640,8 @@ async def get_graphs_executions(
|
||||
async def get_graph_executions(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[graph_db.GraphExecutionMeta]:
|
||||
return await graph_db.get_graph_executions(graph_id=graph_id, user_id=user_id)
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
return await execution_db.get_graph_executions(graph_id=graph_id, user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@ -651,8 +653,10 @@ async def get_graph_execution(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.GraphExecution:
|
||||
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
|
||||
) -> execution_db.GraphExecution:
|
||||
result = await execution_db.get_graph_execution(
|
||||
execution_id=graph_exec_id, user_id=user_id
|
||||
)
|
||||
if not result or result.graph_id != graph_id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
|
||||
@ -671,7 +675,9 @@ async def delete_graph_execution(
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> None:
|
||||
await execution_db.delete_execution(graph_exec_id=graph_exec_id, user_id=user_id)
|
||||
await execution_db.delete_graph_execution(
|
||||
graph_exec_id=graph_exec_id, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
########################################################
|
||||
|
@ -12,7 +12,7 @@ from backend.data import redis
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import ExecutionSubscription, Methods, WsMessage
|
||||
from backend.server.model import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
|
||||
from backend.util.service import AppProcess, get_service_client
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
@ -52,7 +52,7 @@ async def event_broadcaster(manager: ConnectionManager):
|
||||
redis.connect()
|
||||
event_queue = AsyncRedisExecutionEventBus()
|
||||
async for event in event_queue.listen():
|
||||
await manager.send_execution_result(event)
|
||||
await manager.send_execution_update(event)
|
||||
except Exception as e:
|
||||
logger.exception(f"Event broadcaster error: {e}")
|
||||
raise
|
||||
@ -85,18 +85,18 @@ async def handle_subscribe(
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
message: WsMessage,
|
||||
message: WSMessage,
|
||||
):
|
||||
if not message.data:
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.ERROR,
|
||||
WSMessage(
|
||||
method=WSMethod.ERROR,
|
||||
success=False,
|
||||
error="Subscription data missing",
|
||||
).model_dump_json()
|
||||
)
|
||||
else:
|
||||
sub_req = ExecutionSubscription.model_validate(message.data)
|
||||
sub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
|
||||
|
||||
# Verify that user has read access to graph
|
||||
# if not get_db_client().get_graph(
|
||||
@ -113,21 +113,20 @@ async def handle_subscribe(
|
||||
# )
|
||||
# return
|
||||
|
||||
await connection_manager.subscribe(
|
||||
channel_key = await connection_manager.subscribe_graph_exec(
|
||||
user_id=user_id,
|
||||
graph_id=sub_req.graph_id,
|
||||
graph_version=sub_req.graph_version,
|
||||
graph_exec_id=sub_req.graph_exec_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
logger.debug(
|
||||
f"New execution subscription for user #{user_id} "
|
||||
f"graph #{sub_req.graph_id}v{sub_req.graph_version}"
|
||||
f"New subscription for user #{user_id}, "
|
||||
f"graph execution #{sub_req.graph_exec_id}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
success=True,
|
||||
channel=f"{user_id}_{sub_req.graph_id}_{sub_req.graph_version}",
|
||||
channel=channel_key,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
@ -136,33 +135,32 @@ async def handle_unsubscribe(
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
message: WsMessage,
|
||||
message: WSMessage,
|
||||
):
|
||||
if not message.data:
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.ERROR,
|
||||
WSMessage(
|
||||
method=WSMethod.ERROR,
|
||||
success=False,
|
||||
error="Subscription data missing",
|
||||
).model_dump_json()
|
||||
)
|
||||
else:
|
||||
unsub_req = ExecutionSubscription.model_validate(message.data)
|
||||
await connection_manager.unsubscribe(
|
||||
unsub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
|
||||
channel_key = await connection_manager.unsubscribe(
|
||||
user_id=user_id,
|
||||
graph_id=unsub_req.graph_id,
|
||||
graph_version=unsub_req.graph_version,
|
||||
graph_exec_id=unsub_req.graph_exec_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
logger.debug(
|
||||
f"Removed execution subscription for user #{user_id} "
|
||||
f"graph #{unsub_req.graph_id}v{unsub_req.graph_version}"
|
||||
f"Removed subscription for user #{user_id}, "
|
||||
f"graph execution #{unsub_req.graph_exec_id}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.UNSUBSCRIBE,
|
||||
WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE,
|
||||
success=True,
|
||||
channel=f"{unsub_req.graph_id}_{unsub_req.graph_version}",
|
||||
channel=channel_key,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
@ -179,20 +177,24 @@ async def websocket_router(
|
||||
user_id = await authenticate_websocket(websocket)
|
||||
if not user_id:
|
||||
return
|
||||
await manager.connect(websocket)
|
||||
await manager.connect_socket(websocket)
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = WsMessage.model_validate_json(data)
|
||||
message = WSMessage.model_validate_json(data)
|
||||
|
||||
if message.method == Methods.HEARTBEAT:
|
||||
if message.method == WSMethod.HEARTBEAT:
|
||||
await websocket.send_json(
|
||||
{"method": Methods.HEARTBEAT.value, "data": "pong", "success": True}
|
||||
{
|
||||
"method": WSMethod.HEARTBEAT.value,
|
||||
"data": "pong",
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
if message.method == Methods.SUBSCRIBE:
|
||||
if message.method == WSMethod.SUBSCRIBE_GRAPH_EXEC:
|
||||
await handle_subscribe(
|
||||
connection_manager=manager,
|
||||
websocket=websocket,
|
||||
@ -201,7 +203,7 @@ async def websocket_router(
|
||||
)
|
||||
continue
|
||||
|
||||
elif message.method == Methods.UNSUBSCRIBE:
|
||||
elif message.method == WSMethod.UNSUBSCRIBE:
|
||||
await handle_unsubscribe(
|
||||
connection_manager=manager,
|
||||
websocket=websocket,
|
||||
@ -216,7 +218,7 @@ async def websocket_router(
|
||||
)
|
||||
continue
|
||||
|
||||
if message.method == Methods.ERROR:
|
||||
if message.method == WSMethod.ERROR:
|
||||
logger.error(f"WebSocket Error message received: {message.data}")
|
||||
|
||||
else:
|
||||
@ -225,15 +227,15 @@ async def websocket_router(
|
||||
f"{message.data}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.ERROR,
|
||||
WSMessage(
|
||||
method=WSMethod.ERROR,
|
||||
success=False,
|
||||
error="Message type is not processed by the server",
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
manager.disconnect_socket(websocket)
|
||||
logger.debug("WebSocket client disconnected")
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ from typing import Sequence, cast
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block import Block, BlockSchema, initialize_blocks
|
||||
from backend.data.execution import ExecutionResult, ExecutionStatus
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import _BaseCredentials
|
||||
from backend.data.user import create_default_user
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
@ -63,7 +63,7 @@ async def wait_execution(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
timeout: int = 30,
|
||||
) -> Sequence[ExecutionResult]:
|
||||
) -> Sequence[NodeExecutionResult]:
|
||||
async def is_execution_completed():
|
||||
status = await AgentServer().test_get_graph_run_status(graph_exec_id, user_id)
|
||||
log.info(f"Execution status: {status}")
|
||||
|
1
autogpt_platform/backend/poetry.lock
generated
1
autogpt_platform/backend/poetry.lock
generated
@ -299,7 +299,6 @@ pydantic-settings = "^2.7.1"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^0.25.3"
|
||||
pytest-mock = "^3.14.0"
|
||||
python-dotenv = "^1.0.1"
|
||||
supabase = "^2.13.0"
|
||||
|
||||
[package.source]
|
||||
|
@ -2,9 +2,12 @@ import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs
|
||||
# Set up logging
|
||||
configure_logging()
|
||||
|
@ -4,9 +4,13 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.data.execution import ExecutionResult, ExecutionStatus
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import Methods, WsMessage
|
||||
from backend.server.model import WSMessage, WSMethod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -25,7 +29,7 @@ def mock_websocket() -> AsyncMock:
|
||||
async def test_connect(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
await connection_manager.connect(mock_websocket)
|
||||
await connection_manager.connect_socket(mock_websocket)
|
||||
assert mock_websocket in connection_manager.active_connections
|
||||
mock_websocket.accept.assert_called_once()
|
||||
|
||||
@ -34,37 +38,39 @@ def test_disconnect(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.active_connections.add(mock_websocket)
|
||||
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
|
||||
connection_manager.subscriptions["test_channel_42"] = {mock_websocket}
|
||||
|
||||
connection_manager.disconnect(mock_websocket)
|
||||
connection_manager.disconnect_socket(mock_websocket)
|
||||
|
||||
assert mock_websocket not in connection_manager.active_connections
|
||||
assert mock_websocket not in connection_manager.subscriptions["test_graph_1"]
|
||||
assert mock_websocket not in connection_manager.subscriptions["test_channel_42"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
await connection_manager.subscribe(
|
||||
await connection_manager.subscribe_graph_exec(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="graph-exec-1",
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
assert mock_websocket in connection_manager.subscriptions["user-1_test_graph_1"]
|
||||
assert (
|
||||
mock_websocket
|
||||
in connection_manager.subscriptions["user-1|graph_exec#graph-exec-1"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
channel_key = "user-1|graph_exec#graph-exec-1"
|
||||
connection_manager.subscriptions[channel_key] = {mock_websocket}
|
||||
|
||||
await connection_manager.unsubscribe(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="graph-exec-1",
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
|
||||
@ -72,15 +78,46 @@ async def test_unsubscribe(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_execution_result(
|
||||
async def test_send_graph_execution_result(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
channel_key = "user-1|graph_exec#graph-exec-1"
|
||||
connection_manager.subscriptions[channel_key] = {mock_websocket}
|
||||
result = GraphExecutionEvent(
|
||||
id="graph-exec-1",
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec_id",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
cost=0,
|
||||
duration=1.2,
|
||||
total_run_time=0.5,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
ended_at=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
await connection_manager.send_execution_update(result)
|
||||
|
||||
mock_websocket.send_text.assert_called_once_with(
|
||||
WSMessage(
|
||||
method=WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
channel="user-1|graph_exec#graph-exec-1",
|
||||
data=result.model_dump(),
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_node_execution_result(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
channel_key = "user-1|graph_exec#graph-exec-1"
|
||||
connection_manager.subscriptions[channel_key] = {mock_websocket}
|
||||
result = NodeExecutionEvent(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="graph-exec-1",
|
||||
node_exec_id="test_node_exec_id",
|
||||
node_id="test_node_id",
|
||||
block_id="test_block_id",
|
||||
@ -93,12 +130,12 @@ async def test_send_execution_result(
|
||||
end_time=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
await connection_manager.send_execution_result(result)
|
||||
await connection_manager.send_execution_update(result)
|
||||
|
||||
mock_websocket.send_text.assert_called_once_with(
|
||||
WsMessage(
|
||||
method=Methods.EXECUTION_EVENT,
|
||||
channel="user-1_test_graph_1",
|
||||
WSMessage(
|
||||
method=WSMethod.NODE_EXECUTION_EVENT,
|
||||
channel="user-1|graph_exec#graph-exec-1",
|
||||
data=result.model_dump(),
|
||||
).model_dump_json()
|
||||
)
|
||||
@ -108,12 +145,13 @@ async def test_send_execution_result(
|
||||
async def test_send_execution_result_user_mismatch(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
channel_key = "user-1|graph_exec#graph-exec-1"
|
||||
connection_manager.subscriptions[channel_key] = {mock_websocket}
|
||||
result = NodeExecutionEvent(
|
||||
user_id="user-2",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec_id",
|
||||
graph_exec_id="graph-exec-1",
|
||||
node_exec_id="test_node_exec_id",
|
||||
node_id="test_node_id",
|
||||
block_id="test_block_id",
|
||||
@ -126,7 +164,7 @@ async def test_send_execution_result_user_mismatch(
|
||||
end_time=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
await connection_manager.send_execution_result(result)
|
||||
await connection_manager.send_execution_update(result)
|
||||
|
||||
mock_websocket.send_text.assert_not_called()
|
||||
|
||||
@ -135,7 +173,7 @@ async def test_send_execution_result_user_mismatch(
|
||||
async def test_send_execution_result_no_subscribers(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
result = NodeExecutionEvent(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
@ -152,6 +190,6 @@ async def test_send_execution_result_no_subscribers(
|
||||
end_time=datetime.now(),
|
||||
)
|
||||
|
||||
await connection_manager.send_execution_result(result)
|
||||
await connection_manager.send_execution_update(result)
|
||||
|
||||
mock_websocket.send_text.assert_not_called()
|
||||
|
@ -7,8 +7,8 @@ from fastapi import WebSocket, WebSocketDisconnect
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.ws_api import (
|
||||
Methods,
|
||||
WsMessage,
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
handle_subscribe,
|
||||
handle_unsubscribe,
|
||||
websocket_router,
|
||||
@ -30,28 +30,33 @@ async def test_websocket_router_subscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
data={"graph_id": "test_graph", "graph_version": 1},
|
||||
WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
data={"graph_exec_id": "test-graph-exec-1"},
|
||||
).model_dump_json(),
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
mock_manager.subscribe_graph_exec.return_value = (
|
||||
f"{DEFAULT_USER_ID}|graph_exec#test-graph-exec-1"
|
||||
)
|
||||
|
||||
await websocket_router(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.subscribe.assert_called_once_with(
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.subscribe_graph_exec.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test-graph-exec-1",
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert (
|
||||
'"method":"subscribe_graph_execution"'
|
||||
in mock_websocket.send_text.call_args[0][0]
|
||||
)
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
mock_manager.disconnect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -59,28 +64,30 @@ async def test_websocket_router_unsubscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WsMessage(
|
||||
method=Methods.UNSUBSCRIBE,
|
||||
data={"graph_id": "test_graph", "graph_version": 1},
|
||||
WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE,
|
||||
data={"graph_exec_id": "test-graph-exec-1"},
|
||||
).model_dump_json(),
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
mock_manager.unsubscribe.return_value = (
|
||||
f"{DEFAULT_USER_ID}|graph_exec#test-graph-exec-1"
|
||||
)
|
||||
|
||||
await websocket_router(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.unsubscribe.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test-graph-exec-1",
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
mock_manager.disconnect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -88,7 +95,7 @@ async def test_websocket_router_invalid_method(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WsMessage(method=Methods.EXECUTION_EVENT).model_dump_json(),
|
||||
WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT).model_dump_json(),
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
|
||||
@ -96,19 +103,23 @@ async def test_websocket_router_invalid_method(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
|
||||
mock_manager.disconnect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_subscribe_success(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
message = WsMessage(
|
||||
method=Methods.SUBSCRIBE, data={"graph_id": "test_graph", "graph_version": 1}
|
||||
message = WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
data={"graph_exec_id": "test-graph-exec-id"},
|
||||
)
|
||||
mock_manager.subscribe_graph_exec.return_value = (
|
||||
"user-1|graph_exec#test-graph-exec-id"
|
||||
)
|
||||
|
||||
await handle_subscribe(
|
||||
@ -118,14 +129,16 @@ async def test_handle_subscribe_success(
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.subscribe.assert_called_once_with(
|
||||
mock_manager.subscribe_graph_exec.assert_called_once_with(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test-graph-exec-id",
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert (
|
||||
'"method":"subscribe_graph_execution"'
|
||||
in mock_websocket.send_text.call_args[0][0]
|
||||
)
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
|
||||
|
||||
@ -133,7 +146,7 @@ async def test_handle_subscribe_success(
|
||||
async def test_handle_subscribe_missing_data(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
message = WsMessage(method=Methods.SUBSCRIBE)
|
||||
message = WSMessage(method=WSMethod.SUBSCRIBE_GRAPH_EXEC)
|
||||
|
||||
await handle_subscribe(
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
@ -142,7 +155,7 @@ async def test_handle_subscribe_missing_data(
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.subscribe.assert_not_called()
|
||||
mock_manager.subscribe_graph_exec.assert_not_called()
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
|
||||
@ -152,9 +165,10 @@ async def test_handle_subscribe_missing_data(
|
||||
async def test_handle_unsubscribe_success(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
message = WsMessage(
|
||||
method=Methods.UNSUBSCRIBE, data={"graph_id": "test_graph", "graph_version": 1}
|
||||
message = WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE, data={"graph_exec_id": "test-graph-exec-id"}
|
||||
)
|
||||
mock_manager.unsubscribe.return_value = "user-1|graph_exec#test-graph-exec-id"
|
||||
|
||||
await handle_unsubscribe(
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
@ -165,8 +179,7 @@ async def test_handle_unsubscribe_success(
|
||||
|
||||
mock_manager.unsubscribe.assert_called_once_with(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test-graph-exec-id",
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
@ -178,7 +191,7 @@ async def test_handle_unsubscribe_success(
|
||||
async def test_handle_unsubscribe_missing_data(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
message = WsMessage(method=Methods.UNSUBSCRIBE)
|
||||
message = WSMessage(method=WSMethod.UNSUBSCRIBE)
|
||||
|
||||
await handle_unsubscribe(
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
|
@ -0,0 +1,21 @@
|
||||
import AgentFlowListSkeleton from "@/components/monitor/skeletons/AgentFlowListSkeleton";
|
||||
import React from "react";
|
||||
import FlowRunsListSkeleton from "@/components/monitor/skeletons/FlowRunsListSkeleton";
|
||||
import FlowRunsStatusSkeleton from "@/components/monitor/skeletons/FlowRunsStatusSkeleton";
|
||||
|
||||
export default function MonitorLoadingSkeleton() {
|
||||
return (
|
||||
<div className="space-y-4 p-4">
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-3">
|
||||
{/* Agents Section */}
|
||||
<AgentFlowListSkeleton />
|
||||
|
||||
{/* Runs Section */}
|
||||
<FlowRunsListSkeleton />
|
||||
|
||||
{/* Stats Section */}
|
||||
<FlowRunsStatusSkeleton />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -90,7 +90,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
);
|
||||
api.getGraphExecutions(agent.agent_id).then((agentRuns) => {
|
||||
const sortedRuns = agentRuns.toSorted(
|
||||
(a, b) => b.started_at - a.started_at,
|
||||
(a, b) => Number(b.started_at) - Number(a.started_at),
|
||||
);
|
||||
setAgentRuns(sortedRuns);
|
||||
|
||||
@ -102,7 +102,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
if (!selectedView.id && isFirstLoad && sortedRuns.length > 0) {
|
||||
// only for first load or first execution
|
||||
setIsFirstLoad(false);
|
||||
selectView({ type: "run", id: sortedRuns[0].execution_id });
|
||||
selectView({ type: "run", id: sortedRuns[0].id });
|
||||
}
|
||||
});
|
||||
});
|
||||
@ -121,10 +121,8 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
useEffect(() => {
|
||||
if (selectedView.type != "run" || !selectedView.id || !agent) return;
|
||||
|
||||
const newSelectedRun = agentRuns.find(
|
||||
(run) => run.execution_id == selectedView.id,
|
||||
);
|
||||
if (selectedView.id !== selectedRun?.execution_id) {
|
||||
const newSelectedRun = agentRuns.find((run) => run.id == selectedView.id);
|
||||
if (selectedView.id !== selectedRun?.id) {
|
||||
// Pull partial data from "cache" while waiting for the rest to load
|
||||
setSelectedRun(newSelectedRun ?? null);
|
||||
|
||||
@ -136,14 +134,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
setSelectedRun(run);
|
||||
});
|
||||
}
|
||||
}, [
|
||||
api,
|
||||
selectedView,
|
||||
agent,
|
||||
agentRuns,
|
||||
selectedRun?.execution_id,
|
||||
getGraphVersion,
|
||||
]);
|
||||
}, [api, selectedView, agent, agentRuns, selectedRun?.id, getGraphVersion]);
|
||||
|
||||
const fetchSchedules = useCallback(async () => {
|
||||
if (!agent) return;
|
||||
@ -169,17 +160,15 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
const deleteRun = useCallback(
|
||||
async (run: GraphExecutionMeta) => {
|
||||
if (run.status == "RUNNING" || run.status == "QUEUED") {
|
||||
await api.stopGraphExecution(run.graph_id, run.execution_id);
|
||||
await api.stopGraphExecution(run.graph_id, run.id);
|
||||
}
|
||||
await api.deleteGraphExecution(run.execution_id);
|
||||
await api.deleteGraphExecution(run.id);
|
||||
|
||||
setConfirmingDeleteAgentRun(null);
|
||||
if (selectedView.type == "run" && selectedView.id == run.execution_id) {
|
||||
if (selectedView.type == "run" && selectedView.id == run.id) {
|
||||
openRunDraftView();
|
||||
}
|
||||
setAgentRuns(
|
||||
agentRuns.filter((r) => r.execution_id !== run.execution_id),
|
||||
);
|
||||
setAgentRuns(agentRuns.filter((r) => r.id !== run.id));
|
||||
},
|
||||
[agentRuns, api, selectedView, openRunDraftView],
|
||||
);
|
||||
|
@ -102,9 +102,7 @@ const Monitor = () => {
|
||||
: executions),
|
||||
].sort((a, b) => Number(b.started_at) - Number(a.started_at))}
|
||||
selectedRun={selectedRun}
|
||||
onSelectRun={(r) =>
|
||||
setSelectedRun(r.execution_id == selectedRun?.execution_id ? null : r)
|
||||
}
|
||||
onSelectRun={(r) => setSelectedRun(r.id == selectedRun?.id ? null : r)}
|
||||
/>
|
||||
{(selectedRun && (
|
||||
<FlowRunInfo
|
||||
|
@ -90,8 +90,8 @@ export default function AgentRunDetailsView({
|
||||
);
|
||||
|
||||
const stopRun = useCallback(
|
||||
() => api.stopGraphExecution(graph.id, run.execution_id),
|
||||
[api, graph.id, run.execution_id],
|
||||
() => api.stopGraphExecution(graph.id, run.id),
|
||||
[api, graph.id, run.id],
|
||||
);
|
||||
|
||||
const agentRunOutputs:
|
||||
|
@ -111,8 +111,8 @@ export default function AgentRunsSelectorList({
|
||||
status={agentRunStatusMap[run.status]}
|
||||
title={agent.name}
|
||||
timestamp={run.started_at}
|
||||
selected={selectedView.id === run.execution_id}
|
||||
onClick={() => onSelectRun(run.execution_id)}
|
||||
selected={selectedView.id === run.id}
|
||||
onClick={() => onSelectRun(run.id)}
|
||||
onDelete={() => onDeleteRun(run)}
|
||||
/>
|
||||
))
|
||||
|
@ -125,7 +125,9 @@ export const AgentFlowList = ({
|
||||
if (!a.lastRun && !b.lastRun) return 0;
|
||||
if (!a.lastRun) return 1;
|
||||
if (!b.lastRun) return -1;
|
||||
return b.lastRun.started_at - a.lastRun.started_at;
|
||||
return (
|
||||
Number(b.lastRun.started_at) - Number(a.lastRun.started_at)
|
||||
);
|
||||
})
|
||||
.map(({ flow, runCount, lastRun }) => (
|
||||
<TableRow
|
||||
|
@ -27,7 +27,7 @@ export const FlowRunInfo: React.FC<
|
||||
|
||||
const fetchBlockResults = useCallback(async () => {
|
||||
const executionResults = (
|
||||
await api.getGraphExecutionInfo(flow.agent_id, execution.execution_id)
|
||||
await api.getGraphExecutionInfo(flow.agent_id, execution.id)
|
||||
).node_executions;
|
||||
|
||||
// Create a map of the latest COMPLETED execution results of output nodes by node_id
|
||||
@ -69,7 +69,7 @@ export const FlowRunInfo: React.FC<
|
||||
result: result.output_data?.output || undefined,
|
||||
})),
|
||||
);
|
||||
}, [api, flow.agent_id, execution.execution_id]);
|
||||
}, [api, flow.agent_id, execution.id]);
|
||||
|
||||
// Fetch graph and execution data
|
||||
useEffect(() => {
|
||||
@ -84,8 +84,8 @@ export const FlowRunInfo: React.FC<
|
||||
}
|
||||
|
||||
const handleStopRun = useCallback(() => {
|
||||
api.stopGraphExecution(flow.agent_id, execution.execution_id);
|
||||
}, [api, flow.agent_id, execution.execution_id]);
|
||||
api.stopGraphExecution(flow.agent_id, execution.id);
|
||||
}, [api, flow.agent_id, execution.id]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -109,7 +109,7 @@ export const FlowRunInfo: React.FC<
|
||||
{flow.can_access_graph && (
|
||||
<Link
|
||||
className={buttonVariants({ variant: "default" })}
|
||||
href={`/build?flowID=${execution.graph_id}&flowVersion=${execution.graph_version}&flowExecutionID=${execution.execution_id}`}
|
||||
href={`/build?flowID=${execution.graph_id}&flowVersion=${execution.graph_version}&flowExecutionID=${execution.id}`}
|
||||
>
|
||||
<Pencil2Icon className="mr-2" /> Open in Builder
|
||||
</Link>
|
||||
@ -121,7 +121,7 @@ export const FlowRunInfo: React.FC<
|
||||
<strong>Agent ID:</strong> <code>{flow.agent_id}</code>
|
||||
</p>
|
||||
<p className="hidden">
|
||||
<strong>Run ID:</strong> <code>{execution.execution_id}</code>
|
||||
<strong>Run ID:</strong> <code>{execution.id}</code>
|
||||
</p>
|
||||
<div>
|
||||
<strong>Status:</strong>{" "}
|
||||
|
@ -37,17 +37,13 @@ export const FlowRunsList: React.FC<{
|
||||
<TableBody data-testid="flow-runs-list-body">
|
||||
{executions.map((execution) => (
|
||||
<TableRow
|
||||
key={execution.execution_id}
|
||||
data-testid={`flow-run-${execution.execution_id}-graph-${execution.graph_id}`}
|
||||
data-runid={execution.execution_id}
|
||||
key={execution.id}
|
||||
data-testid={`flow-run-${execution.id}-graph-${execution.graph_id}`}
|
||||
data-runid={execution.id}
|
||||
data-graphid={execution.graph_id}
|
||||
className="cursor-pointer"
|
||||
onClick={() => onSelectRun(execution)}
|
||||
data-state={
|
||||
selectedRun?.execution_id == execution.execution_id
|
||||
? "selected"
|
||||
: null
|
||||
}
|
||||
data-state={selectedRun?.id == execution.id ? "selected" : null}
|
||||
>
|
||||
<TableCell>
|
||||
<TextRenderer
|
||||
|
@ -29,7 +29,7 @@ export const FlowRunsStatus: React.FC<{
|
||||
: statsSince;
|
||||
const filteredFlowRuns =
|
||||
statsSinceTimestamp != null
|
||||
? executions.filter((fr) => fr.started_at > statsSinceTimestamp)
|
||||
? executions.filter((fr) => Number(fr.started_at) > statsSinceTimestamp)
|
||||
: executions;
|
||||
|
||||
return (
|
||||
|
@ -99,7 +99,7 @@ export const FlowRunsTimeline = ({
|
||||
.filter((e) => e.graph_id == flow.agent_id)
|
||||
.map((e) => ({
|
||||
...e,
|
||||
time: e.started_at + e.total_run_time * 1000,
|
||||
time: Number(e.started_at) + e.total_run_time * 1000,
|
||||
_duration: e.total_run_time,
|
||||
}))}
|
||||
name={flow.name}
|
||||
@ -108,14 +108,14 @@ export const FlowRunsTimeline = ({
|
||||
))}
|
||||
{executions.map((execution) => (
|
||||
<Line
|
||||
key={execution.execution_id}
|
||||
key={execution.id}
|
||||
type="linear"
|
||||
dataKey="_duration"
|
||||
data={[
|
||||
{ ...execution, time: execution.started_at, _duration: 0 },
|
||||
{ ...execution, time: Number(execution.started_at), _duration: 0 },
|
||||
{
|
||||
...execution,
|
||||
time: execution.ended_at,
|
||||
time: Number(execution.ended_at),
|
||||
_duration: execution.total_run_time,
|
||||
},
|
||||
]}
|
||||
|
@ -106,24 +106,24 @@ export default function useAgentGraph(
|
||||
|
||||
// Subscribe to execution events
|
||||
useEffect(() => {
|
||||
api.onWebSocketMessage("execution_event", (data) => {
|
||||
api.onWebSocketMessage("node_execution_event", (data) => {
|
||||
if (data.graph_exec_id != flowExecutionID) {
|
||||
return;
|
||||
}
|
||||
setUpdateQueue((prev) => [...prev, data]);
|
||||
});
|
||||
|
||||
if (flowID && flowVersion) {
|
||||
if (flowExecutionID) {
|
||||
api
|
||||
.subscribeToExecution(flowID, flowVersion)
|
||||
.subscribeToGraphExecution(flowExecutionID)
|
||||
.then(() =>
|
||||
console.debug(
|
||||
`Subscribed to execution events for ${flowID} v.${flowVersion}`,
|
||||
`Subscribed to updates for execution #${flowExecutionID}`,
|
||||
),
|
||||
)
|
||||
.catch((error) =>
|
||||
console.error(
|
||||
`Failed to subscribe to execution events for ${flowID} v.${flowVersion}:`,
|
||||
`Failed to subscribe to updates for execution #${flowExecutionID}:`,
|
||||
error,
|
||||
),
|
||||
);
|
||||
@ -235,7 +235,7 @@ export default function useAgentGraph(
|
||||
return newNodes;
|
||||
});
|
||||
},
|
||||
[availableNodes, availableFlows, formatEdgeID, getOutputType],
|
||||
[availableNodes, availableFlows, getOutputType],
|
||||
);
|
||||
|
||||
const getFrontendId = useCallback(
|
||||
@ -636,7 +636,7 @@ export default function useAgentGraph(
|
||||
// Track execution until completed
|
||||
const pendingNodeExecutions: Set<string> = new Set();
|
||||
const cancelExecListener = api.onWebSocketMessage(
|
||||
"execution_event",
|
||||
"node_execution_event",
|
||||
(nodeResult) => {
|
||||
// We are racing the server here, since we need the ID to filter events
|
||||
if (nodeResult.graph_exec_id != flowExecutionID) {
|
||||
|
@ -106,8 +106,9 @@ export default class BackendAPI {
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
///////////// CREDITS //////////////////
|
||||
/////////////// CREDITS ////////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
getUserCredit(): Promise<{ credits: number }> {
|
||||
try {
|
||||
return this._get("/credits");
|
||||
@ -172,8 +173,9 @@ export default class BackendAPI {
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
/////////// ONBOARDING /////////////////
|
||||
////////////// ONBOARDING //////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
getUserOnboarding(): Promise<UserOnboarding> {
|
||||
return this._get("/onboarding");
|
||||
}
|
||||
@ -192,8 +194,9 @@ export default class BackendAPI {
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
/////////// GRAPHS /////////////////////
|
||||
//////////////// GRAPHS ////////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
getBlocks(): Promise<Block[]> {
|
||||
return this._get("/blocks");
|
||||
}
|
||||
@ -398,9 +401,9 @@ export default class BackendAPI {
|
||||
return this._request("POST", "/analytics/log_raw_analytics", analytic);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////
|
||||
/////////// V2 STORE API /////////////////
|
||||
/////////////////////////////////////////
|
||||
////////////////////////////////////////
|
||||
///////////// V2 STORE API /////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
getStoreProfile(): Promise<ProfileDetails | null> {
|
||||
try {
|
||||
@ -529,9 +532,9 @@ export default class BackendAPI {
|
||||
return this._get(url);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////
|
||||
/////////// Admin API ///////////////////
|
||||
/////////////////////////////////////////
|
||||
////////////////////////////////////////
|
||||
////////////// Admin API ///////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
getAdminListingsWithVersions(params?: {
|
||||
status?: SubmissionStatus;
|
||||
@ -553,9 +556,9 @@ export default class BackendAPI {
|
||||
);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////
|
||||
/////////// V2 LIBRARY API //////////////
|
||||
/////////////////////////////////////////
|
||||
////////////////////////////////////////
|
||||
//////////// V2 LIBRARY API ////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
listLibraryAgents(params?: {
|
||||
search_term?: string;
|
||||
@ -651,9 +654,9 @@ export default class BackendAPI {
|
||||
);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////
|
||||
/////////// INTERNAL FUNCTIONS ////////////
|
||||
//////////////////////////////??///////////
|
||||
////////////////////////////////////////
|
||||
////////// INTERNAL FUNCTIONS //////////
|
||||
////////////////////////////////////////
|
||||
|
||||
private _get(path: string, query?: Record<string, any>) {
|
||||
return this._request("GET", path, query);
|
||||
@ -815,104 +818,16 @@ export default class BackendAPI {
|
||||
}
|
||||
}
|
||||
|
||||
startHeartbeat() {
|
||||
this.stopHeartbeat();
|
||||
this.heartbeatInterval = window.setInterval(() => {
|
||||
if (this.webSocket?.readyState === WebSocket.OPEN) {
|
||||
this.webSocket.send(
|
||||
JSON.stringify({
|
||||
method: "heartbeat",
|
||||
data: "ping",
|
||||
success: true,
|
||||
}),
|
||||
);
|
||||
////////////////////////////////////////
|
||||
////////////// WEBSOCKETS //////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
this.heartbeatTimeoutId = window.setTimeout(() => {
|
||||
console.warn("Heartbeat timeout - reconnecting");
|
||||
this.webSocket?.close();
|
||||
this.connectWebSocket();
|
||||
}, this.HEARTBEAT_TIMEOUT);
|
||||
}
|
||||
}, this.HEARTBEAT_INTERVAL);
|
||||
}
|
||||
|
||||
stopHeartbeat() {
|
||||
if (this.heartbeatInterval) {
|
||||
clearInterval(this.heartbeatInterval);
|
||||
this.heartbeatInterval = null;
|
||||
}
|
||||
if (this.heartbeatTimeoutId) {
|
||||
clearTimeout(this.heartbeatTimeoutId);
|
||||
this.heartbeatTimeoutId = null;
|
||||
}
|
||||
}
|
||||
|
||||
handleHeartbeatResponse() {
|
||||
if (this.heartbeatTimeoutId) {
|
||||
clearTimeout(this.heartbeatTimeoutId);
|
||||
this.heartbeatTimeoutId = null;
|
||||
}
|
||||
}
|
||||
|
||||
async connectWebSocket(): Promise<void> {
|
||||
this.wsConnecting ??= new Promise(async (resolve, reject) => {
|
||||
try {
|
||||
const token =
|
||||
(await this.supabaseClient?.auth.getSession())?.data.session
|
||||
?.access_token || "";
|
||||
const wsUrlWithToken = `${this.wsUrl}?token=${token}`;
|
||||
this.webSocket = new WebSocket(wsUrlWithToken);
|
||||
|
||||
this.webSocket.onopen = () => {
|
||||
this.startHeartbeat(); // Start heartbeat when connection opens
|
||||
resolve();
|
||||
};
|
||||
|
||||
this.webSocket.onclose = (event) => {
|
||||
console.warn("WebSocket connection closed", event);
|
||||
this.stopHeartbeat(); // Stop heartbeat when connection closes
|
||||
this.wsConnecting = null;
|
||||
// Attempt to reconnect after a delay
|
||||
setTimeout(() => this.connectWebSocket(), 1000);
|
||||
};
|
||||
|
||||
this.webSocket.onerror = (error) => {
|
||||
console.error("WebSocket error:", error);
|
||||
this.stopHeartbeat(); // Stop heartbeat on error
|
||||
this.wsConnecting = null;
|
||||
reject(error);
|
||||
};
|
||||
|
||||
this.webSocket.onmessage = (event) => {
|
||||
const message: WebsocketMessage = JSON.parse(event.data);
|
||||
|
||||
// Handle heartbeat response
|
||||
if (message.method === "heartbeat" && message.data === "pong") {
|
||||
this.handleHeartbeatResponse();
|
||||
return;
|
||||
}
|
||||
|
||||
if (message.method === "execution_event") {
|
||||
message.data = parseNodeExecutionResultTimestamps(message.data);
|
||||
}
|
||||
this.wsMessageHandlers[message.method]?.forEach((handler) =>
|
||||
handler(message.data),
|
||||
);
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Error connecting to WebSocket:", error);
|
||||
reject(error);
|
||||
}
|
||||
subscribeToGraphExecution(graphExecID: GraphExecutionID): Promise<void> {
|
||||
return this.sendWebSocketMessage("subscribe_graph_execution", {
|
||||
graph_exec_id: graphExecID,
|
||||
});
|
||||
return this.wsConnecting;
|
||||
}
|
||||
|
||||
disconnectWebSocket() {
|
||||
this.stopHeartbeat(); // Stop heartbeat when disconnecting
|
||||
if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) {
|
||||
this.webSocket.close();
|
||||
}
|
||||
}
|
||||
async sendWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
|
||||
method: M,
|
||||
data: WebsocketMessageTypeMap[M],
|
||||
@ -920,7 +835,7 @@ export default class BackendAPI {
|
||||
callCountLimit = 4,
|
||||
): Promise<void> {
|
||||
if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) {
|
||||
const result = this.webSocket.send(JSON.stringify({ method, data }));
|
||||
this.webSocket.send(JSON.stringify({ method, data }));
|
||||
return;
|
||||
}
|
||||
if (callCount >= callCountLimit) {
|
||||
@ -948,11 +863,105 @@ export default class BackendAPI {
|
||||
return () => this.wsMessageHandlers[method].delete(handler);
|
||||
}
|
||||
|
||||
async subscribeToExecution(graphId: string, graphVersion: number) {
|
||||
await this.sendWebSocketMessage("subscribe", {
|
||||
graph_id: graphId,
|
||||
graph_version: graphVersion,
|
||||
async connectWebSocket(): Promise<void> {
|
||||
this.wsConnecting ??= new Promise(async (resolve, reject) => {
|
||||
try {
|
||||
const token =
|
||||
(await this.supabaseClient?.auth.getSession())?.data.session
|
||||
?.access_token || "";
|
||||
const wsUrlWithToken = `${this.wsUrl}?token=${token}`;
|
||||
this.webSocket = new WebSocket(wsUrlWithToken);
|
||||
|
||||
this.webSocket.onopen = () => {
|
||||
this._startWSHeartbeat(); // Start heartbeat when connection opens
|
||||
resolve();
|
||||
};
|
||||
|
||||
this.webSocket.onclose = (event) => {
|
||||
console.warn("WebSocket connection closed", event);
|
||||
this._stopWSHeartbeat(); // Stop heartbeat when connection closes
|
||||
this.wsConnecting = null;
|
||||
// Attempt to reconnect after a delay
|
||||
setTimeout(() => this.connectWebSocket(), 1000);
|
||||
};
|
||||
|
||||
this.webSocket.onerror = (error) => {
|
||||
console.error("WebSocket error:", error);
|
||||
this._stopWSHeartbeat(); // Stop heartbeat on error
|
||||
this.wsConnecting = null;
|
||||
reject(error);
|
||||
};
|
||||
|
||||
this.webSocket.onmessage = (event) => {
|
||||
const message: WebsocketMessage = JSON.parse(event.data);
|
||||
|
||||
// Handle heartbeat response
|
||||
if (message.method === "heartbeat" && message.data === "pong") {
|
||||
this._handleWSHeartbeatResponse();
|
||||
return;
|
||||
}
|
||||
|
||||
if (message.method === "node_execution_event") {
|
||||
message.data = parseNodeExecutionResultTimestamps(message.data);
|
||||
} else if (message.method == "graph_execution_event") {
|
||||
message.data = parseGraphExecutionMetaTimestamps(message.data);
|
||||
}
|
||||
this.wsMessageHandlers[message.method]?.forEach((handler) =>
|
||||
handler(message.data),
|
||||
);
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Error connecting to WebSocket:", error);
|
||||
reject(error);
|
||||
}
|
||||
});
|
||||
return this.wsConnecting;
|
||||
}
|
||||
|
||||
disconnectWebSocket() {
|
||||
this._stopWSHeartbeat(); // Stop heartbeat when disconnecting
|
||||
if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) {
|
||||
this.webSocket.close();
|
||||
}
|
||||
}
|
||||
|
||||
_startWSHeartbeat() {
|
||||
this._stopWSHeartbeat();
|
||||
this.heartbeatInterval = window.setInterval(() => {
|
||||
if (this.webSocket?.readyState === WebSocket.OPEN) {
|
||||
this.webSocket.send(
|
||||
JSON.stringify({
|
||||
method: "heartbeat",
|
||||
data: "ping",
|
||||
success: true,
|
||||
}),
|
||||
);
|
||||
|
||||
this.heartbeatTimeoutId = window.setTimeout(() => {
|
||||
console.warn("Heartbeat timeout - reconnecting");
|
||||
this.webSocket?.close();
|
||||
this.connectWebSocket();
|
||||
}, this.HEARTBEAT_TIMEOUT);
|
||||
}
|
||||
}, this.HEARTBEAT_INTERVAL);
|
||||
}
|
||||
|
||||
_stopWSHeartbeat() {
|
||||
if (this.heartbeatInterval) {
|
||||
clearInterval(this.heartbeatInterval);
|
||||
this.heartbeatInterval = null;
|
||||
}
|
||||
if (this.heartbeatTimeoutId) {
|
||||
clearTimeout(this.heartbeatTimeoutId);
|
||||
this.heartbeatTimeoutId = null;
|
||||
}
|
||||
}
|
||||
|
||||
_handleWSHeartbeatResponse() {
|
||||
if (this.heartbeatTimeoutId) {
|
||||
clearTimeout(this.heartbeatTimeoutId);
|
||||
this.heartbeatTimeoutId = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -963,8 +972,9 @@ type GraphCreateRequestBody = {
|
||||
};
|
||||
|
||||
type WebsocketMessageTypeMap = {
|
||||
subscribe: { graph_id: string; graph_version: number };
|
||||
execution_event: NodeExecutionResult;
|
||||
subscribe_graph_execution: { graph_exec_id: GraphExecutionID };
|
||||
graph_execution_event: GraphExecutionMeta;
|
||||
node_execution_event: NodeExecutionResult;
|
||||
heartbeat: "ping" | "pong";
|
||||
};
|
||||
|
||||
@ -984,19 +994,30 @@ type _PydanticValidationError = {
|
||||
|
||||
/* *** HELPER FUNCTIONS *** */
|
||||
|
||||
function parseGraphExecutionMetaTimestamps(result: any): GraphExecutionMeta {
|
||||
return _parseObjectTimestamps<GraphExecutionMeta>(result, [
|
||||
"started_at",
|
||||
"ended_at",
|
||||
]);
|
||||
}
|
||||
|
||||
function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
|
||||
return {
|
||||
...result,
|
||||
add_time: new Date(result.add_time),
|
||||
queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
|
||||
start_time: result.start_time ? new Date(result.start_time) : undefined,
|
||||
end_time: result.end_time ? new Date(result.end_time) : undefined,
|
||||
};
|
||||
return _parseObjectTimestamps<NodeExecutionResult>(result, [
|
||||
"add_time",
|
||||
"queue_time",
|
||||
"start_time",
|
||||
"end_time",
|
||||
]);
|
||||
}
|
||||
|
||||
function parseScheduleTimestamp(result: any): Schedule {
|
||||
return {
|
||||
...result,
|
||||
next_run_time: new Date(result.next_run_time),
|
||||
};
|
||||
return _parseObjectTimestamps<Schedule>(result, ["next_run_time"]);
|
||||
}
|
||||
|
||||
function _parseObjectTimestamps<T>(obj: any, keys: (keyof T)[]): T {
|
||||
const result = { ...obj };
|
||||
keys.forEach(
|
||||
(key) => (result[key] = result[key] ? new Date(result[key]) : undefined),
|
||||
);
|
||||
return result;
|
||||
}
|
||||
|
@ -230,9 +230,9 @@ export type LinkCreatable = Omit<Link, "id" | "is_static"> & {
|
||||
|
||||
/* Mirror of backend/data/graph.py:GraphExecutionMeta */
|
||||
export type GraphExecutionMeta = {
|
||||
execution_id: GraphExecutionID;
|
||||
started_at: number;
|
||||
ended_at: number;
|
||||
id: GraphExecutionID;
|
||||
started_at: Date;
|
||||
ended_at: Date;
|
||||
cost?: number;
|
||||
duration: number;
|
||||
total_run_time: number;
|
||||
|
Loading…
x
Reference in New Issue
Block a user