mirror of
https://github.com/Significant-Gravitas/Auto-GPT.git
synced 2025-01-09 04:19:02 +08:00
Allow execution of store listings
This commit is contained in:
parent
745aae4aec
commit
44659948e5
@ -6,8 +6,14 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
import prisma
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
|
||||
from prisma.types import AgentGraphWhereInput
|
||||
from prisma.models import (
|
||||
AgentGraph,
|
||||
AgentGraphExecution,
|
||||
AgentNode,
|
||||
AgentNodeLink,
|
||||
StoreListing,
|
||||
)
|
||||
from prisma.types import AgentGraphWhereInput, StoreListingWhereInput
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@ -509,7 +515,6 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
|
||||
async def get_graph(
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
template: bool = False,
|
||||
user_id: str | None = None,
|
||||
for_export: bool = False,
|
||||
) -> GraphModel | None:
|
||||
@ -523,20 +528,38 @@ async def get_graph(
|
||||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
}
|
||||
|
||||
if version is not None:
|
||||
where_clause["version"] = version
|
||||
elif not template:
|
||||
where_clause["isActive"] = True
|
||||
|
||||
# TODO: Fix hack workaround to get adding store agents to work
|
||||
if user_id is not None and not template:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
if graph.userId == user_id:
|
||||
return GraphModel.from_db(graph, for_export) if graph else None
|
||||
|
||||
# If the graph is not owned by the user, we need to check if it's a store listing.
|
||||
if not version:
|
||||
version = graph.version
|
||||
|
||||
store_listing_where: StoreListingWhereInput = {
|
||||
"agentId": graph_id,
|
||||
"agentVersion": version,
|
||||
}
|
||||
|
||||
store_listing = await StoreListing.prisma().find_first(where=store_listing_where)
|
||||
|
||||
# If it is not a store listing, return None
|
||||
if not store_listing:
|
||||
return None
|
||||
|
||||
# If it is a store listing, return the graph model
|
||||
return GraphModel.from_db(graph, for_export) if graph else None
|
||||
|
||||
|
||||
@ -591,9 +614,7 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
|
||||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph, user_id)
|
||||
|
||||
if created_graph := await get_graph(
|
||||
graph.id, graph.version, graph.is_template, user_id=user_id
|
||||
):
|
||||
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
|
||||
return created_graph
|
||||
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
@ -769,7 +769,7 @@ class ExecutionManager(AppService):
|
||||
graph_id: str,
|
||||
data: BlockInput,
|
||||
user_id: str,
|
||||
graph_version: int | None = None,
|
||||
graph_version: int,
|
||||
) -> GraphExecutionEntry:
|
||||
graph: GraphModel | None = self.db_client.get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
|
@ -63,7 +63,10 @@ def execute_graph(**kwargs):
|
||||
try:
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
get_execution_client().add_execution(
|
||||
args.graph_id, args.input_data, args.user_id
|
||||
graph_id=args.graph_id,
|
||||
data=args.input_data,
|
||||
user_id=args.user_id,
|
||||
graph_version=args.graph_version,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing graph {args.graph_id}: {e}")
|
||||
|
@ -317,7 +317,8 @@ async def webhook_ingress_generic(
|
||||
continue
|
||||
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
|
||||
executor.add_execution(
|
||||
node.graph_id,
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
data={f"webhook_{webhook_id}_payload": payload},
|
||||
user_id=webhook.user_id,
|
||||
)
|
||||
|
@ -117,9 +117,14 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_graph(
|
||||
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: dict[typing.Any, typing.Any],
|
||||
user_id: str,
|
||||
):
|
||||
return backend.server.routers.v1.execute_graph(graph_id, node_input, user_id)
|
||||
return backend.server.routers.v1.execute_graph(
|
||||
graph_id, graph_version, node_input, user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_create_graph(
|
||||
|
@ -200,12 +200,11 @@ async def get_graph_all_versions(
|
||||
async def create_new_graph(
|
||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.GraphModel:
|
||||
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
|
||||
return await do_create_graph(create_graph, user_id=user_id)
|
||||
|
||||
|
||||
async def do_create_graph(
|
||||
create_graph: CreateGraph,
|
||||
is_template: bool,
|
||||
# user_id doesn't have to be annotated like on other endpoints,
|
||||
# because create_graph isn't used directly as an endpoint
|
||||
user_id: str,
|
||||
@ -217,7 +216,6 @@ async def do_create_graph(
|
||||
graph = await graph_db.get_graph(
|
||||
create_graph.template_id,
|
||||
create_graph.template_version,
|
||||
template=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
@ -230,8 +228,6 @@ async def do_create_graph(
|
||||
status_code=400, detail="Either graph or template_id must be provided."
|
||||
)
|
||||
|
||||
graph.is_template = is_template
|
||||
graph.is_active = not is_template
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
|
||||
graph = await graph_db.create_graph(graph, user_id=user_id)
|
||||
@ -368,12 +364,13 @@ async def set_graph_active_version(
|
||||
)
|
||||
def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
try:
|
||||
graph_exec = execution_manager_client().add_execution(
|
||||
graph_id, node_input, user_id=user_id
|
||||
graph_id, node_input, user_id=user_id, graph_version=graph_version
|
||||
)
|
||||
return {"id": graph_exec.graph_exec_id}
|
||||
except Exception as e:
|
||||
@ -452,7 +449,7 @@ async def get_templates(
|
||||
async def get_template(
|
||||
graph_id: str, version: int | None = None
|
||||
) -> graph_db.GraphModel:
|
||||
graph = await graph_db.get_graph(graph_id, version, template=True)
|
||||
graph = await graph_db.get_graph(graph_id, version)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
|
||||
return graph
|
||||
@ -466,7 +463,7 @@ async def get_template(
|
||||
async def create_new_template(
|
||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.GraphModel:
|
||||
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
|
||||
return await do_create_graph(create_graph, user_id=user_id)
|
||||
|
||||
|
||||
########################################################
|
||||
|
@ -91,7 +91,7 @@ async def add_agent_to_library(
|
||||
|
||||
# Create a new graph from the template
|
||||
graph = await backend.data.graph.get_graph(
|
||||
agent.id, agent.version, template=True, user_id=user_id
|
||||
agent.id, agent.version, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph:
|
||||
|
@ -253,7 +253,7 @@ async def block_autogen_agent():
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
test_graph.id, test_graph.version, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
|
@ -157,7 +157,7 @@ async def reddit_marketing_agent():
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
test_graph.id, test_graph.version, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)
|
||||
|
@ -79,7 +79,7 @@ async def sample_agent():
|
||||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
test_graph.id, test_graph.version, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)
|
||||
|
@ -31,7 +31,7 @@ async def execute_graph(
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
test_graph.id, test_graph.version, input_data, test_user.id
|
||||
)
|
||||
graph_exec_id = response["id"]
|
||||
logger.info(f"Created execution with ID: {graph_exec_id}")
|
||||
|
Loading…
Reference in New Issue
Block a user