update execution manager with redis

This commit is contained in:
Aarushi 2024-10-21 14:02:43 +01:00
parent e4a9c8216f
commit 9334eee41d
3 changed files with 28 additions and 30 deletions

View File

@ -38,28 +38,6 @@ class NodeExecution(BaseModel):
ExecutionStatus = AgentExecutionStatus
T = TypeVar("T")
class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
This will be shared between different processes
"""
def __init__(self):
self.queue = Manager().Queue()
def add(self, execution: T) -> T:
self.queue.put(execution)
return execution
def get(self) -> T:
return self.queue.get()
def empty(self) -> bool:
return self.queue.empty()
class ExecutionResult(BaseModel):
graph_id: str

View File

@ -2,12 +2,14 @@ import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Generic, TypeVar
from backend.data import redis
from backend.data.execution import ExecutionResult
logger = logging.getLogger(__name__)
T = TypeVar("T")
class DateTimeEncoder(json.JSONEncoder):
def default(self, o):
@ -48,3 +50,21 @@ class RedisEventQueue(AbstractEventQueue):
elif message is not None:
logger.error(f"Failed to get execution result from Redis {message}")
return None
class ExecutionQueue(Generic[T]):
def __init__(self, queue_name: str):
self.redis = redis.get_redis()
self.queue_name = queue_name
def add(self, item: T):
message = json.dumps(item.model_dump(), default=str)
self.redis.lpush(self.queue_name, message)
def get(self) -> T:
while True:
_, message = self.redis.brpop(self.queue_name)
return T.model_validate(json.loads(message))
def empty(self) -> bool:
return self.redis.llen(self.queue_name) == 0

View File

@ -13,13 +13,14 @@ from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from backend.data.queue import ExecutionQueue
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.data import redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecution,
@ -415,6 +416,7 @@ class Executor:
configure_logging()
set_service_name("NodeExecutor")
redis.connect()
cls.node_queue = ExecutionQueue[NodeExecution]("node_execution_queue")
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager()
@ -454,7 +456,6 @@ class Executor:
@error_logged
def on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
):
log_metadata = LogMetadata(
@ -465,7 +466,7 @@ class Executor:
node_id=node_exec.node_id,
block_name="-",
)
q = cls.node_queue
execution_stats = {}
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats
@ -481,7 +482,6 @@ class Executor:
@time_measured
def _on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
@ -491,7 +491,7 @@ class Executor:
for execution in execute_node(
cls.db_client, cls.creds_manager, node_exec, stats
):
q.add(execution)
cls.node_queue.add(execution)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
except Exception as e:
log_metadata.exception(
@ -582,7 +582,7 @@ class Executor:
cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecution]()
queue = ExecutionQueue[NodeExecution]("node_execution_queue")
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
@ -620,7 +620,7 @@ class Executor:
)
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
(exec_data,),
callback=make_exec_callback(exec_data),
)
@ -661,7 +661,7 @@ class ExecutionManager(AppService):
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.queue = ExecutionQueue[GraphExecution]("graph_execution_queue")
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
def run_service(self):