Allow execution of store listings

This commit is contained in:
SwiftyOS 2025-01-03 11:26:28 +01:00
parent 745aae4aec
commit 44659948e5
11 changed files with 57 additions and 30 deletions

View File

@ -6,8 +6,14 @@ from datetime import datetime, timezone
from typing import Any, Literal, Optional, Type
import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphWhereInput
from prisma.models import (
AgentGraph,
AgentGraphExecution,
AgentNode,
AgentNodeLink,
StoreListing,
)
from prisma.types import AgentGraphWhereInput, StoreListingWhereInput
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
@ -509,7 +515,6 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False,
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
@ -523,20 +528,38 @@ async def get_graph(
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}
if version is not None:
where_clause["version"] = version
elif not template:
where_clause["isActive"] = True
# TODO: Fix hack workaround to get adding store agents to work
if user_id is not None and not template:
where_clause["userId"] = user_id
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
if not graph:
return None
if graph.userId == user_id:
return GraphModel.from_db(graph, for_export) if graph else None
# If the graph is not owned by the user, we need to check if it's a store listing.
if not version:
version = graph.version
store_listing_where: StoreListingWhereInput = {
"agentId": graph_id,
"agentVersion": version,
}
store_listing = await StoreListing.prisma().find_first(where=store_listing_where)
# If it is not a store listing, return None
if not store_listing:
return None
# If it is a store listing, return the graph model
return GraphModel.from_db(graph, for_export) if graph else None
@ -591,9 +614,7 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)
if created_graph := await get_graph(
graph.id, graph.version, graph.is_template, user_id=user_id
):
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")

View File

@ -769,7 +769,7 @@ class ExecutionManager(AppService):
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: int | None = None,
graph_version: int,
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version

View File

@ -63,7 +63,10 @@ def execute_graph(**kwargs):
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
args.graph_id, args.input_data, args.user_id
graph_id=args.graph_id,
data=args.input_data,
user_id=args.user_id,
graph_version=args.graph_version,
)
except Exception as e:
logger.exception(f"Error executing graph {args.graph_id}: {e}")

View File

@ -317,7 +317,8 @@ async def webhook_ingress_generic(
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
node.graph_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
)

View File

@ -117,9 +117,14 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_execute_graph(
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
graph_id: str,
graph_version: int,
node_input: dict[typing.Any, typing.Any],
user_id: str,
):
return backend.server.routers.v1.execute_graph(graph_id, node_input, user_id)
return backend.server.routers.v1.execute_graph(
graph_id, graph_version, node_input, user_id
)
@staticmethod
async def test_create_graph(

View File

@ -200,12 +200,11 @@ async def get_graph_all_versions(
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
return await do_create_graph(create_graph, user_id=user_id)
async def do_create_graph(
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
@ -217,7 +216,6 @@ async def do_create_graph(
graph = await graph_db.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
@ -230,8 +228,6 @@ async def do_create_graph(
status_code=400, detail="Either graph or template_id must be provided."
)
graph.is_template = is_template
graph.is_active = not is_template
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph = await graph_db.create_graph(graph, user_id=user_id)
@ -368,12 +364,13 @@ async def set_graph_active_version(
)
def execute_graph(
graph_id: str,
graph_version: int,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, Any]: # FIXME: add proper return type
try:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id
graph_id, node_input, user_id=user_id, graph_version=graph_version
)
return {"id": graph_exec.graph_exec_id}
except Exception as e:
@ -452,7 +449,7 @@ async def get_templates(
async def get_template(
graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(graph_id, version, template=True)
graph = await graph_db.get_graph(graph_id, version)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
return graph
@ -466,7 +463,7 @@ async def get_template(
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
return await do_create_graph(create_graph, user_id=user_id)
########################################################

View File

@ -91,7 +91,7 @@ async def add_agent_to_library(
# Create a new graph from the template
graph = await backend.data.graph.get_graph(
agent.id, agent.version, template=True, user_id=user_id
agent.id, agent.version, user_id=user_id
)
if not graph:

View File

@ -253,7 +253,7 @@ async def block_autogen_agent():
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(

View File

@ -157,7 +157,7 @@ async def reddit_marketing_agent():
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)

View File

@ -79,7 +79,7 @@ async def sample_agent():
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)

View File

@ -31,7 +31,7 @@ async def execute_graph(
# --- Test adding new executions --- #
response = await agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
graph_exec_id = response["id"]
logger.info(f"Created execution with ID: {graph_exec_id}")