mirror of
https://github.com/Significant-Gravitas/Auto-GPT.git
synced 2025-01-08 11:57:32 +08:00
update execution manager with redis
This commit is contained in:
parent
e4a9c8216f
commit
9334eee41d
@ -38,28 +38,6 @@ class NodeExecution(BaseModel):
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
"""
|
||||
Queue for managing the execution of agents.
|
||||
This will be shared between different processes
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queue = Manager().Queue()
|
||||
|
||||
def add(self, execution: T) -> T:
|
||||
self.queue.put(execution)
|
||||
return execution
|
||||
|
||||
def get(self) -> T:
|
||||
return self.queue.get()
|
||||
|
||||
def empty(self) -> bool:
|
||||
return self.queue.empty()
|
||||
|
||||
|
||||
class ExecutionResult(BaseModel):
|
||||
graph_id: str
|
||||
|
@ -2,12 +2,14 @@ import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.execution import ExecutionResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
@ -48,3 +50,21 @@ class RedisEventQueue(AbstractEventQueue):
|
||||
elif message is not None:
|
||||
logger.error(f"Failed to get execution result from Redis {message}")
|
||||
return None
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
def __init__(self, queue_name: str):
|
||||
self.redis = redis.get_redis()
|
||||
self.queue_name = queue_name
|
||||
|
||||
def add(self, item: T):
|
||||
message = json.dumps(item.model_dump(), default=str)
|
||||
self.redis.lpush(self.queue_name, message)
|
||||
|
||||
def get(self) -> T:
|
||||
while True:
|
||||
_, message = self.redis.brpop(self.queue_name)
|
||||
return T.model_validate(json.loads(message))
|
||||
|
||||
def empty(self) -> bool:
|
||||
return self.redis.llen(self.queue_name) == 0
|
||||
|
@ -13,13 +13,14 @@ from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
|
||||
from pydantic import BaseModel
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from backend.data.queue import ExecutionQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionResult,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
@ -415,6 +416,7 @@ class Executor:
|
||||
configure_logging()
|
||||
set_service_name("NodeExecutor")
|
||||
redis.connect()
|
||||
cls.node_queue = ExecutionQueue[NodeExecution]("node_execution_queue")
|
||||
cls.pid = os.getpid()
|
||||
cls.db_client = get_db_client()
|
||||
cls.creds_manager = IntegrationCredentialsManager()
|
||||
@ -454,7 +456,6 @@ class Executor:
|
||||
@error_logged
|
||||
def on_node_execution(
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
):
|
||||
log_metadata = LogMetadata(
|
||||
@ -465,7 +466,7 @@ class Executor:
|
||||
node_id=node_exec.node_id,
|
||||
block_name="-",
|
||||
)
|
||||
|
||||
q = cls.node_queue
|
||||
execution_stats = {}
|
||||
timing_info, _ = cls._on_node_execution(
|
||||
q, node_exec, log_metadata, execution_stats
|
||||
@ -481,7 +482,6 @@ class Executor:
|
||||
@time_measured
|
||||
def _on_node_execution(
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
log_metadata: LogMetadata,
|
||||
stats: dict[str, Any] | None = None,
|
||||
@ -491,7 +491,7 @@ class Executor:
|
||||
for execution in execute_node(
|
||||
cls.db_client, cls.creds_manager, node_exec, stats
|
||||
):
|
||||
q.add(execution)
|
||||
cls.node_queue.add(execution)
|
||||
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
|
||||
except Exception as e:
|
||||
log_metadata.exception(
|
||||
@ -582,7 +582,7 @@ class Executor:
|
||||
cancel_thread.start()
|
||||
|
||||
try:
|
||||
queue = ExecutionQueue[NodeExecution]()
|
||||
queue = ExecutionQueue[NodeExecution]("node_execution_queue")
|
||||
for node_exec in graph_exec.start_node_execs:
|
||||
queue.add(node_exec)
|
||||
|
||||
@ -620,7 +620,7 @@ class Executor:
|
||||
)
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, exec_data),
|
||||
(exec_data,),
|
||||
callback=make_exec_callback(exec_data),
|
||||
)
|
||||
|
||||
@ -661,7 +661,7 @@ class ExecutionManager(AppService):
|
||||
self.use_redis = True
|
||||
self.use_supabase = True
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.queue = ExecutionQueue[GraphExecution]()
|
||||
self.queue = ExecutionQueue[GraphExecution]("graph_execution_queue")
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
|
||||
def run_service(self):
|
||||
|
Loading…
Reference in New Issue
Block a user