fix(backend): Add migrations to fix credentials inputs with invalid provider "llm"(#8674)

In #8524, the "llm" credentials provider was replaced. There are still entries with 	"provider": "llm"	 in the system though, and those break if not migrated.

- SQL migration to fix the obvious ones where we know the provider from `credentials.id`
- Non-SQL migration to fix the rest
This commit is contained in:
Reinier van der Leer 2024-11-15 20:18:02 +01:00 committed by GitHub
parent 0551bec096
commit 4db8e746d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 97 additions and 0 deletions

View File

@ -5,6 +5,7 @@ from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Literal, Type
import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
@ -528,3 +529,84 @@ async def __create_graph(tx, graph: Graph, user_id: str):
for link in graph.links
]
)
# ------------------------ UTILITIES ------------------------ #
async def fix_llm_provider_credentials():
"""Fix node credentials with provider `llm`"""
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)
from .redis import get_redis
from .user import get_user_integrations
store = SupabaseIntegrationCredentialsStore(get_redis())
broken_nodes = await prisma.get_client().query_raw(
"""
SELECT user.id user_id,
node.id node_id,
node."constantInput" node_preset_input
FROM platform."AgentGraph" graph
LEFT JOIN platform."AgentNode" node
ON node."agentGraphId" = graph.id
LEFT JOIN platform."User" user
ON graph."userId" = user.id
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY user_id;
"""
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
user_id: str = ""
user_integrations = None
for node in broken_nodes:
if node["user_id"] != user_id:
# Save queries by only fetching once per user
user_id = node["user_id"]
user_integrations = await get_user_integrations(user_id)
elif not user_integrations:
raise RuntimeError(f"Impossible state while processing node {node}")
node_id: str = node["node_id"]
node_preset_input: dict = json.loads(node["node_preset_input"])
credentials_meta: dict = node_preset_input["credentials"]
credentials = next(
(
c
for c in user_integrations.credentials
if c.id == credentials_meta["id"]
),
None,
)
if not credentials:
continue
if credentials.type != "api_key":
logger.warning(
f"User {user_id} credentials {credentials.id} with provider 'llm' "
f"has invalid type '{credentials.type}'"
)
continue
api_key = credentials.api_key.get_secret_value()
if api_key.startswith("sk-ant-api03-"):
credentials.provider = credentials_meta["provider"] = "anthropic"
elif api_key.startswith("sk-"):
credentials.provider = credentials_meta["provider"] = "openai"
elif api_key.startswith("gsk_"):
credentials.provider = credentials_meta["provider"] = "groq"
else:
logger.warning(
f"Could not identify provider from key prefix {api_key[:13]}*****"
)
continue
store.update_creds(user_id, credentials)
await AgentNode.prisma().update(
where={"id": node_id},
data={"constantInput": json.dumps(node_preset_input)},
)

View File

@ -9,6 +9,7 @@ import uvicorn
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.util.service
@ -23,6 +24,7 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
yield
await backend.data.db.disconnect()

View File

@ -0,0 +1,13 @@
-- Correct credentials.provider field on all nodes with 'llm' provider credentials
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{credentials,provider}',
CASE
WHEN "constantInput"::jsonb->'credentials'->>'id' = '53c25cb8-e3ee-465c-a4d1-e75a4c899c2a' THEN '"openai"'::jsonb
WHEN "constantInput"::jsonb->'credentials'->>'id' = '24e5d942-d9e3-4798-8151-90143ee55629' THEN '"anthropic"'::jsonb
WHEN "constantInput"::jsonb->'credentials'->>'id' = '4ec22295-8f97-4dd1-b42b-2c6957a02545' THEN '"groq"'::jsonb
ELSE ("constantInput"::jsonb->'credentials'->>'provider')::jsonb
END
)::text
WHERE "constantInput"::jsonb->'credentials'->>'provider' = 'llm';