Mirror of https://github.com/Significant-Gravitas/Auto-GPT.git (synced 2025-01-09 04:19:02 +08:00)
test(agent): Fix VCRpy request header filter for cross-platform cassette reuse (#7040)
- Move filtering logic from tests/vcr/__init__.py to tests/vcr/vcr_filter.py
- Ignore all `X-Stainless-*` headers for cassette matching, e.g. `X-Stainless-OS` and `X-Stainless-Runtime-Version`
- Remove deprecated OpenAI proxy logic
- Reorder methods in vcr_filter.py for readability
parent 20041d65bf
commit 6dd76afad5
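For context, the two hooks that the diff below keeps exporting from vcr_filter.py are plugged into VCRpy as `before_record_request` / `before_record_response` callbacks. A minimal standalone sketch of that wiring follows; the repository does this through BASE_VCR_CONFIG and its pytest fixtures, and the cassette path here is illustrative only:

```python
# Illustrative only: the repo configures this via BASE_VCR_CONFIG and pytest
# fixtures; a standalone VCR instance is shown just to make the hook points explicit.
import vcr

from tests.vcr.vcr_filter import before_record_request, before_record_response

my_vcr = vcr.VCR(
    record_mode="new_episodes",
    before_record_request=before_record_request,    # drop ignored headers, freeze body
    before_record_response=before_record_response,  # strip Transfer-Encoding
    match_on=["method", "headers"],
)

with my_vcr.use_cassette("tests/vcr_cassettes/example.yaml"):
    ...  # HTTP calls to cached hostnames are recorded or replayed here
```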
tests/vcr/__init__.py
@@ -10,7 +10,6 @@ from openai._utils import is_given
 from pytest_mock import MockerFixture
 
 from .vcr_filter import (
-    PROXY,
     before_record_request,
     before_record_response,
     freeze_request_body,
@@ -20,15 +19,6 @@ DEFAULT_RECORD_MODE = "new_episodes"
 BASE_VCR_CONFIG = {
     "before_record_request": before_record_request,
     "before_record_response": before_record_response,
-    "filter_headers": [
-        "Authorization",
-        "AGENT-MODE",
-        "AGENT-TYPE",
-        "Cookie",
-        "OpenAI-Organization",
-        "X-OpenAI-Client-User-Agent",
-        "User-Agent",
-    ],
     "match_on": ["method", "headers"],
 }
 
@@ -69,10 +59,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
         options.headers = headers
         data: dict = options.json_data
 
-        if PROXY:
-            headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit())
-            headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit())
-
         logging.getLogger("cached_openai_client").debug(
             f"Outgoing API request: {headers}\n{data if data else None}"
         )
@@ -82,8 +68,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
             freeze_request_body(data), usedforsecurity=False
         ).hexdigest()
 
-    if PROXY:
-        client.base_url = f"{PROXY}/v1"
     mocker.patch.object(
         client,
         "_prepare_options",
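The hashing shown above is what lets `match_on=["method", "headers"]` effectively match on request content as well: the frozen, dynamic-data-free body is hashed and attached as a request header before recording. A minimal sketch of that idea, assuming a header name of `X-Content-Hash` (the actual header name used by the fixture is not visible in this diff):

```python
import hashlib


def attach_content_hash(headers: dict[str, str], frozen_body: bytes) -> None:
    # Hypothetical helper: hash the frozen (dynamic-data-free) request body and
    # expose it as a header, so VCR's header matching also distinguishes bodies.
    headers["X-Content-Hash"] = hashlib.sha256(
        frozen_body, usedforsecurity=False
    ).hexdigest()
```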
tests/vcr/vcr_filter.py
@@ -1,16 +1,27 @@
 import contextlib
 import json
-import os
 import re
 from io import BytesIO
-from typing import Any, Dict, List
-from urllib.parse import urlparse, urlunparse
+from typing import Any
 
 from vcr.request import Request
 
-PROXY = os.environ.get("PROXY")
+HOSTNAMES_TO_CACHE: list[str] = [
+    "api.openai.com",
+    "localhost:50337",
+    "duckduckgo.com",
+]
 
-REPLACEMENTS: List[Dict[str, str]] = [
+IGNORE_REQUEST_HEADERS: set[str | re.Pattern] = {
+    "Authorization",
+    "Cookie",
+    "OpenAI-Organization",
+    "X-OpenAI-Client-User-Agent",
+    "User-Agent",
+    re.compile(r"X-Stainless-[\w\-]+", re.IGNORECASE),
+}
+
+LLM_MESSAGE_REPLACEMENTS: list[dict[str, str]] = [
     {
         "regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
         "replacement": "Tue Jan 1 00:00:00 2000",
@@ -21,46 +32,33 @@ REPLACEMENTS: List[Dict[str, str]] = [
     },
 ]
 
-ALLOWED_HOSTNAMES: List[str] = [
-    "api.openai.com",
-    "localhost:50337",
-    "duckduckgo.com",
-]
-
-if PROXY:
-    ALLOWED_HOSTNAMES.append(PROXY)
-    ORIGINAL_URL = PROXY
-else:
-    ORIGINAL_URL = "no_ci"
-
-NEW_URL = "api.openai.com"
+OPENAI_URL = "api.openai.com"
 
 
-def replace_message_content(content: str, replacements: List[Dict[str, str]]) -> str:
-    for replacement in replacements:
-        pattern = re.compile(replacement["regex"])
-        content = pattern.sub(replacement["replacement"], content)
+def before_record_request(request: Request) -> Request | None:
+    if not should_cache_request(request):
+        return None
 
-    return content
+    request = filter_request_headers(request)
+    request = freeze_request(request)
+    return request
 
 
-def freeze_request_body(body: dict) -> bytes:
-    """Remove any dynamic items from the request body"""
+def should_cache_request(request: Request) -> bool:
+    return any(hostname in request.url for hostname in HOSTNAMES_TO_CACHE)
 
-    if "messages" not in body:
-        return json.dumps(body, sort_keys=True).encode()
-
-    if "max_tokens" in body:
-        del body["max_tokens"]
-
-    for message in body["messages"]:
-        if "content" in message and "role" in message:
-            if message["role"] == "system":
-                message["content"] = replace_message_content(
-                    message["content"], REPLACEMENTS
-                )
-
-    return json.dumps(body, sort_keys=True).encode()
+
+def filter_request_headers(request: Request) -> Request:
+    for header_name in list(request.headers):
+        if any(
+            (
+                (type(ignore) is str and ignore.lower() == header_name.lower())
+                or (isinstance(ignore, re.Pattern) and ignore.match(header_name))
+            )
+            for ignore in IGNORE_REQUEST_HEADERS
+        ):
+            del request.headers[header_name]
+    return request
 
 
 def freeze_request(request: Request) -> Request:
@@ -79,40 +77,34 @@ def freeze_request(request: Request) -> Request:
     return request
 
 
-def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]:
+def freeze_request_body(body: dict) -> bytes:
+    """Remove any dynamic items from the request body"""
+
+    if "messages" not in body:
+        return json.dumps(body, sort_keys=True).encode()
+
+    if "max_tokens" in body:
+        del body["max_tokens"]
+
+    for message in body["messages"]:
+        if "content" in message and "role" in message:
+            if message["role"] == "system":
+                message["content"] = replace_message_content(
+                    message["content"], LLM_MESSAGE_REPLACEMENTS
+                )
+
+    return json.dumps(body, sort_keys=True).encode()
+
+
+def replace_message_content(content: str, replacements: list[dict[str, str]]) -> str:
+    for replacement in replacements:
+        pattern = re.compile(replacement["regex"])
+        content = pattern.sub(replacement["replacement"], content)
+
+    return content
+
+
+def before_record_response(response: dict[str, Any]) -> dict[str, Any]:
     if "Transfer-Encoding" in response["headers"]:
         del response["headers"]["Transfer-Encoding"]
     return response
-
-
-def before_record_request(request: Request) -> Request | None:
-    request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)
-
-    filtered_request = filter_hostnames(request)
-    if not filtered_request:
-        return None
-
-    filtered_request_without_dynamic_data = freeze_request(filtered_request)
-    return filtered_request_without_dynamic_data
-
-
-def replace_request_hostname(
-    request: Request, original_url: str, new_hostname: str
-) -> Request:
-    parsed_url = urlparse(request.uri)
-
-    if parsed_url.hostname in original_url:
-        new_path = parsed_url.path.replace("/proxy_function", "")
-        request.uri = urlunparse(
-            parsed_url._replace(netloc=new_hostname, path=new_path, scheme="https")
-        )
-
-    return request
-
-
-def filter_hostnames(request: Request) -> Request | None:
-    # Add your implementation here for filtering hostnames
-    if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
-        return request
-    else:
-        return None
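To make the cross-platform effect concrete, here is an illustrative check (not part of the commit) of what `filter_request_headers` does to a request carrying the OpenAI client's platform-specific headers. It assumes the `tests` package is importable from the repository root:

```python
from vcr.request import Request

from tests.vcr.vcr_filter import filter_request_headers

request = Request(
    method="POST",
    uri="https://api.openai.com/v1/chat/completions",
    body=b"{}",
    headers={
        "Authorization": "Bearer sk-...",
        "Content-Type": "application/json",
        "X-Stainless-OS": "MacOS",
        "X-Stainless-Runtime-Version": "3.11.7",
    },
)

filtered = filter_request_headers(request)

# Matching-irrelevant headers are gone, so a cassette recorded on macOS/3.11
# still matches a request made on Linux with a different Python version.
assert "Content-Type" in filtered.headers
assert "Authorization" not in filtered.headers
assert not any(h.lower().startswith("x-stainless") for h in filtered.headers)
```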