test(agent): Fix VCRpy request header filter for cross-platform cassette reuse (#7040)

- Move filtering logic from tests/vcr/__init__.py to tests/vcr/vcr_filter.py
- Ignore all `X-Stainless-*` headers for cassette matching, e.g. `X-Stainless-OS` and `X-Stainless-Runtime-Version`
- Remove deprecated OpenAI proxy logic
- Reorder methods in vcr_filter.py for readability
This commit is contained in:
Reinier van der Leer 2024-03-22 13:08:15 +01:00 committed by GitHub
parent 20041d65bf
commit 6dd76afad5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 88 deletions

View File

@ -10,7 +10,6 @@ from openai._utils import is_given
from pytest_mock import MockerFixture
from .vcr_filter import (
PROXY,
before_record_request,
before_record_response,
freeze_request_body,
@ -20,15 +19,6 @@ DEFAULT_RECORD_MODE = "new_episodes"
BASE_VCR_CONFIG = {
"before_record_request": before_record_request,
"before_record_response": before_record_response,
"filter_headers": [
"Authorization",
"AGENT-MODE",
"AGENT-TYPE",
"Cookie",
"OpenAI-Organization",
"X-OpenAI-Client-User-Agent",
"User-Agent",
],
"match_on": ["method", "headers"],
}
@ -69,10 +59,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
options.headers = headers
data: dict = options.json_data
if PROXY:
headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit())
headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit())
logging.getLogger("cached_openai_client").debug(
f"Outgoing API request: {headers}\n{data if data else None}"
)
@ -82,8 +68,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
freeze_request_body(data), usedforsecurity=False
).hexdigest()
if PROXY:
client.base_url = f"{PROXY}/v1"
mocker.patch.object(
client,
"_prepare_options",

View File

@ -1,16 +1,27 @@
import contextlib
import json
import os
import re
from io import BytesIO
from typing import Any, Dict, List
from urllib.parse import urlparse, urlunparse
from typing import Any
from vcr.request import Request
PROXY = os.environ.get("PROXY")
HOSTNAMES_TO_CACHE: list[str] = [
"api.openai.com",
"localhost:50337",
"duckduckgo.com",
]
REPLACEMENTS: List[Dict[str, str]] = [
IGNORE_REQUEST_HEADERS: set[str | re.Pattern] = {
"Authorization",
"Cookie",
"OpenAI-Organization",
"X-OpenAI-Client-User-Agent",
"User-Agent",
re.compile(r"X-Stainless-[\w\-]+", re.IGNORECASE),
}
LLM_MESSAGE_REPLACEMENTS: list[dict[str, str]] = [
{
"regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
"replacement": "Tue Jan 1 00:00:00 2000",
@ -21,46 +32,33 @@ REPLACEMENTS: List[Dict[str, str]] = [
},
]
ALLOWED_HOSTNAMES: List[str] = [
"api.openai.com",
"localhost:50337",
"duckduckgo.com",
]
if PROXY:
ALLOWED_HOSTNAMES.append(PROXY)
ORIGINAL_URL = PROXY
else:
ORIGINAL_URL = "no_ci"
NEW_URL = "api.openai.com"
OPENAI_URL = "api.openai.com"
def replace_message_content(content: str, replacements: List[Dict[str, str]]) -> str:
for replacement in replacements:
pattern = re.compile(replacement["regex"])
content = pattern.sub(replacement["replacement"], content)
def before_record_request(request: Request) -> Request | None:
if not should_cache_request(request):
return None
return content
request = filter_request_headers(request)
request = freeze_request(request)
return request
def freeze_request_body(body: dict) -> bytes:
"""Remove any dynamic items from the request body"""
def should_cache_request(request: Request) -> bool:
return any(hostname in request.url for hostname in HOSTNAMES_TO_CACHE)
if "messages" not in body:
return json.dumps(body, sort_keys=True).encode()
if "max_tokens" in body:
del body["max_tokens"]
for message in body["messages"]:
if "content" in message and "role" in message:
if message["role"] == "system":
message["content"] = replace_message_content(
message["content"], REPLACEMENTS
)
return json.dumps(body, sort_keys=True).encode()
def filter_request_headers(request: Request) -> Request:
for header_name in list(request.headers):
if any(
(
(type(ignore) is str and ignore.lower() == header_name.lower())
or (isinstance(ignore, re.Pattern) and ignore.match(header_name))
)
for ignore in IGNORE_REQUEST_HEADERS
):
del request.headers[header_name]
return request
def freeze_request(request: Request) -> Request:
@ -79,40 +77,34 @@ def freeze_request(request: Request) -> Request:
return request
def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]:
def freeze_request_body(body: dict) -> bytes:
"""Remove any dynamic items from the request body"""
if "messages" not in body:
return json.dumps(body, sort_keys=True).encode()
if "max_tokens" in body:
del body["max_tokens"]
for message in body["messages"]:
if "content" in message and "role" in message:
if message["role"] == "system":
message["content"] = replace_message_content(
message["content"], LLM_MESSAGE_REPLACEMENTS
)
return json.dumps(body, sort_keys=True).encode()
def replace_message_content(content: str, replacements: list[dict[str, str]]) -> str:
for replacement in replacements:
pattern = re.compile(replacement["regex"])
content = pattern.sub(replacement["replacement"], content)
return content
def before_record_response(response: dict[str, Any]) -> dict[str, Any]:
if "Transfer-Encoding" in response["headers"]:
del response["headers"]["Transfer-Encoding"]
return response
def before_record_request(request: Request) -> Request | None:
request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)
filtered_request = filter_hostnames(request)
if not filtered_request:
return None
filtered_request_without_dynamic_data = freeze_request(filtered_request)
return filtered_request_without_dynamic_data
def replace_request_hostname(
request: Request, original_url: str, new_hostname: str
) -> Request:
parsed_url = urlparse(request.uri)
if parsed_url.hostname in original_url:
new_path = parsed_url.path.replace("/proxy_function", "")
request.uri = urlunparse(
parsed_url._replace(netloc=new_hostname, path=new_path, scheme="https")
)
return request
def filter_hostnames(request: Request) -> Request | None:
# Add your implementation here for filtering hostnames
if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
return request
else:
return None