Mirror of https://github.com/Significant-Gravitas/Auto-GPT.git
synced 2025-01-08 11:57:32 +08:00
Set up unified pre-commit + CI w/ linting + type checking & FIX EVERYTHING (#7171)
- **FIX ALL LINT/TYPE ERRORS IN AUTOGPT, FORGE, AND BENCHMARK**

### Linting
- Clean up linter configs for `autogpt`, `forge`, and `benchmark`
- Add type checking with Pyright
- Create unified pre-commit config
- Create unified linting and type checking CI workflow

### Testing
- Synchronize CI test setups for `autogpt`, `forge`, and `benchmark`
  - Add missing pytest-cov to benchmark dependencies
  - Mark GCS tests as slow to speed up pre-commit test runs
- Repair `forge` test suite
  - Add `AgentDB.close()` method for test DB teardown in db_test.py
  - Use an actual temporary dir instead of forge/test_workspace/
- Move left-behind dependencies for moved `forge` code from autogpt to forge

### Notable type changes
- Replace uses of `ChatModelProvider` by `MultiProvider`
- Remove unnecessary exports from various `__init__.py` files
- Simplify `FileStorage.open_file` signature by removing `IOBase` from the return type union
- Implement `S3BinaryIOWrapper(BinaryIO)` type interposer for `S3FileStorage` (a sketch appears near the S3 test changes below)
- Expand overloads of `GCSFileStorage.open_file` for improved typing of read and write modes.
  Had to silence type checking for the extra overloads, because (I think) Pyright is reporting a false positive:
  https://github.com/microsoft/pyright/issues/8007
- Change the `count_tokens`, `get_tokenizer`, and `count_message_tokens` methods on `ModelProvider`s from class methods to instance methods
- Move `CompletionModelFunction.schema` method -> helper function `format_function_def_for_openai` in `forge.llm.providers.openai`
- Rename `ModelProvider` -> `BaseModelProvider`
- Rename `ChatModelProvider` -> `BaseChatModelProvider`
- Add type `ChatModelProvider`, a union of all subclasses of `BaseChatModelProvider` (see the sketch below)

### Removed rather than fixed
- Remove deprecated and broken autogpt/agbenchmark_config/benchmarks.py
- Remove various base classes and properties on base classes in `forge.llm.providers.schema` and `forge.models.providers`

### Fixes for other issues that came to light
- Clean up `forge.agent_protocol.api_router`, `forge.agent_protocol.database`, and `forge.agent.agent`
- Add fallback behavior to `ImageGeneratorComponent`, and remove the test for the deprecated failure behavior
- Fix `agbenchmark.challenges.builtin` challenge exclusion mechanism on Windows
- Fix `_tool_calls_compat_extract_calls` in `forge.llm.providers.openai`
- Add support for `any` (= no type specified) in `JSONSchema.typescript_type`
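To illustrate the provider rename and the new union type, here is a minimal sketch; the subclass bodies are invented for illustration, and only the names `BaseChatModelProvider` / `ChatModelProvider` come from this commit:

```python
from abc import ABC, abstractmethod


class BaseChatModelProvider(ABC):  # formerly named `ChatModelProvider`
    @abstractmethod
    def count_tokens(self, text: str) -> int:  # now an instance method
        ...


class OpenAIProvider(BaseChatModelProvider):
    def count_tokens(self, text: str) -> int:
        return len(text) // 4  # placeholder tokenizer logic


class AnthropicProvider(BaseChatModelProvider):
    def count_tokens(self, text: str) -> int:
        return len(text) // 4  # placeholder tokenizer logic


# The new `ChatModelProvider` is a union of the concrete subclasses, so call
# sites that accept "any chat provider" stay precisely typed:
ChatModelProvider = OpenAIProvider | AnthropicProvider
```

Call sites in the diffs below instead take `MultiProvider`, which, going by its name and the commit message, presumably dispatches to the matching concrete provider.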
This commit is contained in:
parent 2c13a2706c
commit f107ff8cf0

.github/workflows/autogpt-ci.yml (53 changes, vendored)
@@ -1,4 +1,4 @@
-name: AutoGPT Python CI
+name: AutoGPT CI

on:
  push:
@@ -24,57 +24,6 @@ defaults:
    working-directory: autogpt

jobs:
-  lint:
-    runs-on: ubuntu-latest
-    env:
-      min-python-version: "3.10"
-
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-
-      - name: Set up Python ${{ env.min-python-version }}
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ env.min-python-version }}
-
-      - id: get_date
-        name: Get date
-        run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
-
-      - name: Set up Python dependency cache
-        uses: actions/cache@v4
-        with:
-          path: ~/.cache/pypoetry
-          key: ${{ runner.os }}-poetry-${{ hashFiles('autogpt/pyproject.toml') }}-${{ steps.get_date.outputs.date }}
-
-      - name: Install Python dependencies
-        run: |
-          curl -sSL https://install.python-poetry.org | python3 -
-          poetry install
-
-      - name: Lint with flake8
-        run: poetry run flake8
-
-      - name: Check black formatting
-        run: poetry run black . --check
-        if: success() || failure()
-
-      - name: Check isort formatting
-        run: poetry run isort . --check
-        if: success() || failure()
-
-      # - name: Check mypy formatting
-      #   run: poetry run mypy
-      #   if: success() || failure()
-
-      # - name: Check for unused imports and pass statements
-      #   run: |
-      #     cmd="autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests"
-      #     poetry run $cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
-
  test:
    permissions:
      contents: read
.github/workflows/autogpts-ci.yml (4 changes, vendored)
@@ -1,4 +1,4 @@
-name: AutoGPTs smoke test CI
+name: Agent smoke tests

on:
  workflow_dispatch:
@@ -28,7 +28,7 @@ on:
      - '!**/*.md'

jobs:
-  run-tests:
+  serve-agent-protocol:
    runs-on: ubuntu-latest
    strategy:
      matrix:
.github/workflows/benchmark-ci.yml (102 changes, vendored)
@@ -1,4 +1,4 @@
-name: Benchmark CI
+name: AGBenchmark CI

on:
  push:
@@ -14,62 +14,91 @@ on:
      - '!benchmark/reports/**'
      - .github/workflows/benchmark-ci.yml

concurrency:
  group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}

defaults:
  run:
    shell: bash

env:
  min-python-version: '3.10'

jobs:
-  lint:
-    runs-on: ubuntu-latest
-
  test:
    permissions:
      contents: read
+    timeout-minutes: 30
    strategy:
      fail-fast: false
      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
+    defaults:
+      run:
+        shell: bash
+        working-directory: benchmark
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: true

-      - name: Set up Python ${{ env.min-python-version }}
+      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
-          python-version: ${{ env.min-python-version }}
+          python-version: ${{ matrix.python-version }}

-      - id: get_date
-        name: Get date
-        working-directory: ./benchmark/
-        run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
+      - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
        uses: actions/cache@v4
        with:
          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
          key: poetry-${{ runner.os }}-${{ hashFiles('benchmark/poetry.lock') }}

-      - name: Install Poetry
-        working-directory: ./benchmark/
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
        run: |
-          curl -sSL https://install.python-poetry.org | python -
+          curl -sSL https://install.python-poetry.org | python3 -

-      - name: Install dependencies
-        working-directory: ./benchmark/
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi

+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
        run: |
-          export POETRY_VIRTUALENVS_IN_PROJECT=true
-          poetry install -vvv
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -

-      - name: Lint with flake8
-        working-directory: ./benchmark/
-        run: poetry run flake8
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH

-      - name: Check black formatting
-        working-directory: ./benchmark/
-        run: poetry run black . --exclude test.py --check
-        if: success() || failure()
+      - name: Install Python dependencies
+        run: poetry install

-      - name: Check isort formatting
-        working-directory: ./benchmark/
-        run: poetry run isort . --check
-        if: success() || failure()

-      - name: Check for unused imports and pass statements
-        working-directory: ./benchmark/
+      - name: Run pytest with coverage
        run: |
-          cmd="poetry run autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring agbenchmark"
-          $cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
-        if: success() || failure()
+          poetry run pytest -vv \
+            --cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
+            --durations=10 \
+            tests
+        env:
+          CI: true
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

-  tests-agbenchmark:
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          flags: agbenchmark,${{ runner.os }}

+  self-test-with-agent:
    runs-on: ubuntu-latest
    strategy:
      matrix:
@@ -89,11 +118,11 @@ jobs:
          python-version: ${{ env.min-python-version }}

      - name: Install Poetry
-        working-directory: ./${{ matrix.agent-name }}/
        run: |
          curl -sSL https://install.python-poetry.org | python -

      - name: Run regression tests
+        working-directory: .
        run: |
          ./run agent start ${{ matrix.agent-name }}
          cd ${{ matrix.agent-name }}
@@ -125,7 +154,6 @@ jobs:
          export BUILD_SKILL_TREE=true

          poetry run agbenchmark --mock
-          poetry run pytest -vv -s tests

          CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../frontend/assets)') || echo "No diffs"
          if [ ! -z "$CHANGED" ]; then
.github/workflows/forge-ci.yml (new file, 129 lines, vendored)
@@ -0,0 +1,129 @@
name: Forge CI

on:
  push:
    branches: [ master, development, ci-test* ]
    paths:
      - '.github/workflows/forge-ci.yml'
      - 'forge/**'
  pull_request:
    branches: [ master, development, release-* ]
    paths:
      - '.github/workflows/forge-ci.yml'
      - 'forge/**'

concurrency:
  group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}

defaults:
  run:
    shell: bash
    working-directory: forge

jobs:
  test:
    permissions:
      contents: read
    timeout-minutes: 30
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10"]
        platform-os: [ubuntu, macos, macos-arm64, windows]
    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}

    steps:
      # Quite slow on macOS (2~4 minutes to set up Docker)
      # - name: Set up Docker (macOS)
      #   if: runner.os == 'macOS'
      #   uses: crazy-max/ghaction-setup-docker@v3

      - name: Start MinIO service (Linux)
        if: runner.os == 'Linux'
        working-directory: '.'
        run: |
          docker pull minio/minio:edge-cicd
          docker run -d -p 9000:9000 minio/minio:edge-cicd

      - name: Start MinIO service (macOS)
        if: runner.os == 'macOS'
        working-directory: ${{ runner.temp }}
        run: |
          brew install minio/stable/minio
          mkdir data
          minio server ./data &

      # No MinIO on Windows:
      # - Windows doesn't support running Linux Docker containers
      # - It doesn't seem possible to start background processes on Windows. They are
      #   killed after the step returns.
      # See: https://github.com/actions/runner/issues/598#issuecomment-2011890429

      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: true

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Set up Python dependency cache
        # On Windows, unpacking cached dependencies takes longer than just installing them
        if: runner.os != 'Windows'
        uses: actions/cache@v4
        with:
          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
          key: poetry-${{ runner.os }}-${{ hashFiles('forge/poetry.lock') }}

      - name: Install Poetry (Unix)
        if: runner.os != 'Windows'
        run: |
          curl -sSL https://install.python-poetry.org | python3 -

          if [ "${{ runner.os }}" = "macOS" ]; then
            PATH="$HOME/.local/bin:$PATH"
            echo "$HOME/.local/bin" >> $GITHUB_PATH
          fi

      - name: Install Poetry (Windows)
        if: runner.os == 'Windows'
        shell: pwsh
        run: |
          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -

          $env:PATH += ";$env:APPDATA\Python\Scripts"
          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH

      - name: Install Python dependencies
        run: poetry install

      - name: Run pytest with coverage
        run: |
          poetry run pytest -vv \
            --cov=forge --cov-branch --cov-report term-missing --cov-report xml \
            --durations=10 \
            forge
        env:
          CI: true
          PLAIN_OUTPUT: True
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
          AWS_ACCESS_KEY_ID: minioadmin
          AWS_SECRET_ACCESS_KEY: minioadmin

      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          flags: forge,${{ runner.os }}

      - name: Upload logs to artifact
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: test-logs
          path: forge/logs/
.github/workflows/python-checks.yml (new file, 151 lines, vendored)
@@ -0,0 +1,151 @@
name: Python checks

on:
  push:
    branches: [ master, development, ci-test* ]
    paths:
      - '.github/workflows/lint-ci.yml'
      - 'autogpt/**'
      - 'forge/**'
      - 'benchmark/**'
      - '**.py'
      - '!autogpt/tests/vcr_cassettes'
  pull_request:
    branches: [ master, development, release-* ]
    paths:
      - '.github/workflows/lint-ci.yml'
      - 'autogpt/**'
      - 'forge/**'
      - 'benchmark/**'
      - '**.py'
      - '!autogpt/tests/vcr_cassettes'

concurrency:
  group: ${{ format('lint-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}

defaults:
  run:
    shell: bash

jobs:
  get-changed-parts:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - id: changes-in
        name: Determine affected subprojects
        uses: dorny/paths-filter@v3
        with:
          filters: |
            autogpt:
              - autogpt/autogpt/**
              - autogpt/tests/**
              - autogpt/poetry.lock
            forge:
              - forge/forge/**
              - forge/tests/**
              - forge/poetry.lock
            benchmark:
              - benchmark/agbenchmark/**
              - benchmark/tests/**
              - benchmark/poetry.lock
    outputs:
      changed-parts: ${{ steps.changes-in.outputs.changes }}

  lint:
    needs: get-changed-parts
    runs-on: ubuntu-latest
    env:
      min-python-version: "3.10"

    strategy:
      matrix:
        sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
      fail-fast: false

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python ${{ env.min-python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ env.min-python-version }}

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}

      - name: Install Poetry
        run: curl -sSL https://install.python-poetry.org | python3 -

      # Install dependencies

      - name: Install Python dependencies
        run: poetry -C ${{ matrix.sub-package }} install

      # Lint

      - name: Lint (isort)
        run: poetry run isort --check .
        working-directory: ${{ matrix.sub-package }}

      - name: Lint (Black)
        if: success() || failure()
        run: poetry run black --check .
        working-directory: ${{ matrix.sub-package }}

      - name: Lint (Flake8)
        if: success() || failure()
        run: poetry run flake8 .
        working-directory: ${{ matrix.sub-package }}

  types:
    needs: get-changed-parts
    runs-on: ubuntu-latest
    env:
      min-python-version: "3.10"

    strategy:
      matrix:
        sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
      fail-fast: false

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python ${{ env.min-python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ env.min-python-version }}

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}

      - name: Install Poetry
        run: curl -sSL https://install.python-poetry.org | python3 -

      # Install dependencies

      - name: Install Python dependencies
        run: poetry -C ${{ matrix.sub-package }} install

      # Typecheck

      - name: Typecheck
        if: success() || failure()
        run: poetry run pyright
        working-directory: ${{ matrix.sub-package }}
.pre-commit-config.yaml (new file, 127 lines)
@@ -0,0 +1,127 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-added-large-files
        args: ["--maxkb=500"]
      - id: fix-byte-order-marker
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: check-symlinks
      - id: debug-statements

  - repo: local
    # isort needs the context of which packages are installed to function, so we
    # can't use a vendored isort pre-commit hook (which runs in its own isolated venv).
    hooks:
      - id: isort-autogpt
        name: Lint (isort) - AutoGPT
        entry: poetry -C autogpt run isort
        files: ^autogpt/
        types: [file, python]
        language: system

      - id: isort-forge
        name: Lint (isort) - Forge
        entry: poetry -C forge run isort
        files: ^forge/
        types: [file, python]
        language: system

      - id: isort-benchmark
        name: Lint (isort) - Benchmark
        entry: poetry -C benchmark run isort
        files: ^benchmark/
        types: [file, python]
        language: system

  - repo: https://github.com/psf/black
    rev: 23.12.1
    # Black has sensible defaults, doesn't need package context, and ignores
    # everything in .gitignore, so it works fine without any config or arguments.
    hooks:
      - id: black
        name: Lint (Black)
        language_version: python3.10

  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0
    # To have flake8 load the config of the individual subprojects, we have to call
    # them separately.
    hooks:
      - id: flake8
        name: Lint (Flake8) - AutoGPT
        alias: flake8-autogpt
        files: ^autogpt/(autogpt|scripts|tests)/
        args: [--config=autogpt/.flake8]

      - id: flake8
        name: Lint (Flake8) - Forge
        alias: flake8-forge
        files: ^forge/(forge|tests)/
        args: [--config=forge/.flake8]

      - id: flake8
        name: Lint (Flake8) - Benchmark
        alias: flake8-benchmark
        files: ^benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
        args: [--config=benchmark/.flake8]

  - repo: local
    # To have watertight type checking, we check *all* the files in an affected
    # project. To trigger on poetry.lock we also reset the file `types` filter.
    hooks:
      - id: pyright
        name: Typecheck - AutoGPT
        alias: pyright-autogpt
        entry: poetry -C autogpt run pyright
        args: [-p, autogpt, autogpt]
        # include forge source (since it's a path dependency) but exclude *_test.py files:
        files: ^(autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
        types: [file]
        language: system
        pass_filenames: false

      - id: pyright
        name: Typecheck - Forge
        alias: pyright-forge
        entry: poetry -C forge run pyright
        args: [-p, forge, forge]
        files: ^forge/(forge/|poetry\.lock$)
        types: [file]
        language: system
        pass_filenames: false

      - id: pyright
        name: Typecheck - Benchmark
        alias: pyright-benchmark
        entry: poetry -C benchmark run pyright
        args: [-p, benchmark, benchmark]
        files: ^benchmark/(agbenchmark|tests)/
        types: [file]
        language: system
        pass_filenames: false

  - repo: local
    hooks:
      - id: pytest-autogpt
        name: Run tests - AutoGPT (excl. slow tests)
        entry: bash -c 'cd autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
        # include forge source (since it's a path dependency) but exclude *_test.py files:
        files: ^(autogpt/((autogpt|tests)/|poetry\.lock$)|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
        language: system
        pass_filenames: false

      - id: pytest-forge
        name: Run tests - Forge (excl. slow tests)
        entry: bash -c 'cd forge && poetry run pytest --cov=forge -m "not slow"'
        files: ^forge/(forge/|tests/|poetry\.lock$)
        language: system
        pass_filenames: false

      - id: pytest-benchmark
        name: Run tests - Benchmark
        entry: bash -c 'cd benchmark && poetry run pytest --cov=benchmark'
        files: ^benchmark/(agbenchmark/|tests/|poetry\.lock$)
        language: system
        pass_filenames: false
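A note on the local isort hooks above: isort decides whether an import is first-party or third-party by inspecting the installed environment, which is why these hooks run through each subproject's Poetry env rather than a vendored hook. The grouping it enforces (with Black-compatible settings) looks roughly like this toy example, which is not a file from the repo and only imports cleanly when `forge` is installed:

```python
from __future__ import annotations  # "future" imports come first

import os  # standard library
import sys

import pytest  # third-party

from forge.file_storage import FileStorage  # first-party: only classified as
                                            # such when `forge` is installed
```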
@@ -1,11 +1,14 @@
[flake8]
max-line-length = 88
+extend-exclude =
+    .*_cache/,
+    .venv,
-# Ignore rules that conflict with Black code style
-extend-ignore = E203, W503
-exclude =
-    .git,
    __pycache__/,
    *.pyc,
    .pytest_cache/,
    venv*/,
    .venv/,
    data/,
    logs/,
    tests/unit/data/,
+extend-ignore =
+    # No whitespace before ':' conflicts with Black style for slices
+    E203,
@@ -1,47 +0,0 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-added-large-files
        args: ['--maxkb=500']
      - id: check-byte-order-marker
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: check-symlinks
      - id: debug-statements

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        language_version: python3.10

  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3.10

  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0
    hooks:
      - id: flake8

  # - repo: https://github.com/pre-commit/mirrors-mypy
  #   rev: 'v1.3.0'
  #   hooks:
  #     - id: mypy

  - repo: local
    hooks:
      # - id: autoflake
      #   name: autoflake
      #   entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests
      #   language: python
      #   types: [ python ]
      - id: pytest-check
        name: pytest-check
        entry: bash -c 'cd autogpt && poetry run pytest --cov=autogpt tests/unit'
        language: system
        pass_filenames: false
        always_run: true
@@ -1,77 +0,0 @@
import asyncio
import logging
import sys
from pathlib import Path

from forge.config.ai_profile import AIProfile
from forge.config.config import ConfigBuilder
from forge.file_storage import FileStorageBackendName, get_storage
from forge.logging.config import configure_logging

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.agents.agent_manager import AgentManager
from autogpt.app.main import _configure_llm_provider, run_interaction_loop

LOG_DIR = Path(__file__).parent / "logs"


def run_specific_agent(task: str, continuous_mode: bool = False) -> None:
    agent = bootstrap_agent(task, continuous_mode)
    asyncio.run(run_interaction_loop(agent))


def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
    configure_logging(
        level=logging.DEBUG,
        log_dir=LOG_DIR,
        plain_console_output=True,
    )

    config = ConfigBuilder.build_config_from_env()
    config.continuous_mode = continuous_mode
    config.continuous_limit = 20
    config.noninteractive_mode = True

    ai_profile = AIProfile(
        ai_name="AutoGPT",
        ai_role="a multi-purpose AI assistant.",
        ai_goals=[task],
    )

    agent_settings = AgentSettings(
        name=Agent.default_settings.name,
        agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
        description=Agent.default_settings.description,
        ai_profile=ai_profile,
        config=AgentConfiguration(
            fast_llm=config.fast_llm,
            smart_llm=config.smart_llm,
            allow_fs_access=not config.restrict_to_workspace,
            use_functions_api=config.openai_functions,
        ),
        history=Agent.default_settings.history.copy(deep=True),
    )

    local = config.file_storage_backend == FileStorageBackendName.LOCAL
    restrict_to_root = not local or config.restrict_to_workspace
    file_storage = get_storage(
        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
    )
    file_storage.initialize()

    agent = Agent(
        settings=agent_settings,
        llm_provider=_configure_llm_provider(config),
        file_storage=file_storage,
        legacy_config=config,
    )
    return agent


if __name__ == "__main__":
    # The first argument is the script name itself, second is the task
    if len(sys.argv) != 2:
        print("Usage: python script.py <task>")
        sys.exit(1)
    task = sys.argv[1]
    run_specific_agent(task, continuous_mode=True)
@@ -4,7 +4,7 @@ from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.config.config import Config
from forge.file_storage.base import FileStorage
-from forge.llm.providers import ChatModelProvider
+from forge.llm.providers import MultiProvider

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings

@@ -15,7 +15,7 @@ def create_agent(
    ai_profile: AIProfile,
    app_config: Config,
    file_storage: FileStorage,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
    directives: Optional[AIDirectives] = None,
) -> Agent:
    if not task:
@@ -39,7 +39,7 @@ def configure_agent_with_state(
    state: AgentSettings,
    app_config: Config,
    file_storage: FileStorage,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
) -> Agent:
    return _configure_agent(
        state=state,
@@ -51,7 +51,7 @@ def configure_agent_with_state(

def _configure_agent(
    app_config: Config,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
    file_storage: FileStorage,
    agent_id: str = "",
    task: str = "",
@@ -59,20 +59,22 @@ def _configure_agent(
    directives: Optional[AIDirectives] = None,
    state: Optional[AgentSettings] = None,
) -> Agent:
-    if not (state or agent_id and task and ai_profile and directives):
+    if state:
+        agent_state = state
+    elif agent_id and task and ai_profile and directives:
+        agent_state = state or create_agent_state(
+            agent_id=agent_id,
+            task=task,
+            ai_profile=ai_profile,
+            directives=directives,
+            app_config=app_config,
+        )
+    else:
        raise TypeError(
            "Either (state) or (agent_id, task, ai_profile, directives)"
            " must be specified"
        )

-    agent_state = state or create_agent_state(
-        agent_id=agent_id,
-        task=task,
-        ai_profile=ai_profile,
-        directives=directives,
-        app_config=app_config,
-    )

    return Agent(
        settings=agent_state,
        llm_provider=llm_provider,
@@ -7,7 +7,7 @@ from forge.file_storage.base import FileStorage
if TYPE_CHECKING:
    from autogpt.agents.agent import Agent
    from forge.config.config import Config
-    from forge.llm.providers.schema import ChatModelProvider
+    from forge.llm.providers import MultiProvider

from .configurators import _configure_agent
from .profile_generator import generate_agent_profile_for_task
@@ -18,7 +18,7 @@ async def generate_agent_for_task(
    task: str,
    app_config: Config,
    file_storage: FileStorage,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
) -> Agent:
    ai_profile, task_directives = await generate_agent_profile_for_task(
        task=task,
@@ -5,10 +5,10 @@ from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.config.config import Config
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
+from forge.llm.providers import MultiProvider
from forge.llm.providers.schema import (
    AssistantChatMessage,
    ChatMessage,
-    ChatModelProvider,
    CompletionModelFunction,
)
from forge.models.config import SystemConfiguration, UserConfigurable
@@ -141,7 +141,7 @@ class AgentProfileGeneratorConfiguration(SystemConfiguration):
                required=True,
            ),
        },
-    ).schema
+    ).dict()
)


@@ -160,7 +160,7 @@ class AgentProfileGenerator(PromptStrategy):
        self._model_classification = model_classification
        self._system_prompt_message = system_prompt
        self._user_prompt_template = user_prompt_template
-        self._create_agent_function = CompletionModelFunction.parse(
+        self._create_agent_function = CompletionModelFunction.parse_obj(
            create_agent_function
        )

@@ -183,7 +183,7 @@ class AgentProfileGenerator(PromptStrategy):

    def parse_response_content(
        self,
-        response_content: AssistantChatMessage,
+        response: AssistantChatMessage,
    ) -> tuple[AIProfile, AIDirectives]:
        """Parse the actual text response from the objective model.

@@ -195,15 +195,15 @@ class AgentProfileGenerator(PromptStrategy):

        """
        try:
-            if not response_content.tool_calls:
+            if not response.tool_calls:
                raise ValueError(
                    f"LLM did not call {self._create_agent_function.name} function; "
                    "agent profile creation failed"
                )
-            arguments: object = response_content.tool_calls[0].function.arguments
+            arguments: object = response.tool_calls[0].function.arguments
            ai_profile = AIProfile(
-                ai_name=arguments.get("name"),
-                ai_role=arguments.get("description"),
+                ai_name=arguments.get("name"),  # type: ignore
+                ai_role=arguments.get("description"),  # type: ignore
            )
            ai_directives = AIDirectives(
                best_practices=arguments.get("directives", {}).get("best_practices"),
@@ -211,7 +211,7 @@ class AgentProfileGenerator(PromptStrategy):
                resources=[],
            )
        except KeyError:
-            logger.debug(f"Failed to parse this response content: {response_content}")
+            logger.debug(f"Failed to parse this response content: {response}")
            raise
        return ai_profile, ai_directives

@@ -219,7 +219,7 @@ class AgentProfileGenerator(PromptStrategy):
async def generate_agent_profile_for_task(
    task: str,
    app_config: Config,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
) -> tuple[AIProfile, AIDirectives]:
    """Generates an AIConfig object from the given string.
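For context on the `.schema` -> `.dict()` and `parse` -> `parse_obj` changes above: `CompletionModelFunction` is a Pydantic v1 model, and v1 round-trips between dicts and models like this (toy model with assumed field names, not the real class):

```python
from pydantic import BaseModel  # pydantic v1 API, as used in this codebase


class FunctionSpec(BaseModel):  # hypothetical stand-in for CompletionModelFunction
    name: str
    description: str


spec = {"name": "create_agent", "description": "Create a new agent."}
fn = FunctionSpec.parse_obj(spec)  # dict -> validated model instance
assert fn.dict() == spec           # model -> plain dict, usable as prompt payload
```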
@@ -24,7 +24,7 @@ class MyAgent(Agent):
    def __init__(
        self,
        settings: AgentSettings,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
        file_storage: FileStorage,
        legacy_config: Config,
    ):
@@ -3,7 +3,7 @@ from __future__ import annotations
import inspect
import logging
from datetime import datetime
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Optional

import sentry_sdk
from forge.agent.base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
@@ -14,7 +14,7 @@ from forge.agent.protocols import (
    DirectiveProvider,
    MessageProvider,
)
-from forge.command.command import Command, CommandOutput
+from forge.command.command import Command
from forge.components.action_history import (
    ActionHistoryComponent,
    EpisodicActionHistory,
@@ -34,8 +34,8 @@ from forge.llm.prompting.utils import dump_prompt
from forge.llm.providers import (
    AssistantFunctionCall,
    ChatMessage,
-    ChatModelProvider,
    ChatModelResponse,
+    MultiProvider,
)
from forge.llm.providers.utils import function_specs_from_commands
from forge.models.action import (
@@ -76,7 +76,9 @@ class AgentConfiguration(BaseAgentConfiguration):


class AgentSettings(BaseAgentSettings):
-    config: AgentConfiguration = Field(default_factory=AgentConfiguration)
+    config: AgentConfiguration = Field(  # type: ignore
+        default_factory=AgentConfiguration
+    )

    history: EpisodicActionHistory[OneShotAgentActionProposal] = Field(
        default_factory=EpisodicActionHistory[OneShotAgentActionProposal]
@@ -86,8 +88,8 @@ class AgentSettings(BaseAgentSettings):
    context: AgentContext = Field(default_factory=AgentContext)


-class Agent(BaseAgent, Configurable[AgentSettings]):
-    default_settings: AgentSettings = AgentSettings(
+class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
+    default_settings: ClassVar[AgentSettings] = AgentSettings(
        name="Agent",
        description=__doc__ if __doc__ else "",
    )
@@ -95,7 +97,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
    def __init__(
        self,
        settings: AgentSettings,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
        file_storage: FileStorage,
        legacy_config: Config,
    ):
@@ -280,7 +282,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):

        return result

-    async def _execute_tool(self, tool_call: AssistantFunctionCall) -> CommandOutput:
+    async def _execute_tool(self, tool_call: AssistantFunctionCall) -> Any:
        """Execute the command and return the result

        Args:

@@ -43,7 +43,7 @@ class AssistantThoughts(ModelWithSummary):


class OneShotAgentActionProposal(ActionProposal):
-    thoughts: AssistantThoughts
+    thoughts: AssistantThoughts  # type: ignore


class OneShotAgentPromptConfiguration(SystemConfiguration):
@@ -186,11 +186,8 @@ class OneShotAgentPromptStrategy(PromptStrategy):

    def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
        response_schema = self.response_schema.copy(deep=True)
-        if (
-            use_functions_api
-            and response_schema.properties
-            and "use_tool" in response_schema.properties
-        ):
+        assert response_schema.properties
+        if use_functions_api and "use_tool" in response_schema.properties:
            del response_schema.properties["use_tool"]

        # Unindent for performance
@@ -288,10 +285,10 @@ class OneShotAgentPromptStrategy(PromptStrategy):
            "Parsing object extracted from LLM response:\n"
            f"{json.dumps(assistant_reply_dict, indent=4)}"
        )

-        parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
        if self.config.use_functions_api:
            if not response.tool_calls:
                raise InvalidAgentResponseError("Assistant did not use a tool")
-            parsed_response.use_tool = response.tool_calls[0].function
+            assistant_reply_dict["use_tool"] = response.tool_calls[0].function

+        parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
        return parsed_response
@@ -25,7 +25,7 @@ from forge.agent_protocol.models import (
)
from forge.config.config import Config
from forge.file_storage import FileStorage
-from forge.llm.providers import ChatModelProvider, ModelProviderBudget
+from forge.llm.providers import ModelProviderBudget, MultiProvider
from forge.models.action import ActionErrorResult, ActionSuccessResult
from forge.utils.const import ASK_COMMAND, FINISH_COMMAND
from forge.utils.exceptions import AgentFinished, NotFoundError
@@ -49,7 +49,7 @@ class AgentProtocolServer:
        app_config: Config,
        database: AgentDB,
        file_storage: FileStorage,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
    ):
        self.app_config = app_config
        self.db = database
@@ -444,9 +444,7 @@ class AgentProtocolServer:
        agent_id = task_agent_id(task_id)
        return self.file_storage.clone_with_subroot(f"agents/{agent_id}/workspace")

-    def _get_task_llm_provider(
-        self, task: Task, step_id: str = ""
-    ) -> ChatModelProvider:
+    def _get_task_llm_provider(self, task: Task, step_id: str = "") -> MultiProvider:
        """
        Configures the LLM provider with headers to link outgoing requests to the task.
        """
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional

from colorama import Fore, Style
from forge.agent_protocol.database import AgentDB
-from forge.components.code_executor import (
+from forge.components.code_executor.code_executor import (
    is_docker_available,
    we_are_running_in_a_docker_container,
)
@@ -82,7 +82,9 @@ async def run_auto_gpt(
    local = config.file_storage_backend == FileStorageBackendName.LOCAL
    restrict_to_root = not local or config.restrict_to_workspace
    file_storage = get_storage(
-        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
+        config.file_storage_backend,
+        root_path=Path("data"),
+        restrict_to_root=restrict_to_root,
    )
    file_storage.initialize()

@@ -353,7 +355,9 @@ async def run_auto_gpt_server(
    local = config.file_storage_backend == FileStorageBackendName.LOCAL
    restrict_to_root = not local or config.restrict_to_workspace
    file_storage = get_storage(
-        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
+        config.file_storage_backend,
+        root_path=Path("data"),
+        restrict_to_root=restrict_to_root,
    )
    file_storage.initialize()
@@ -7,7 +7,7 @@ import re
import socket
import sys
from pathlib import Path
-from typing import Any, Callable, Coroutine, ParamSpec, TypeVar
+from typing import Any, Callable, Coroutine, ParamSpec, TypeVar, cast

import requests
from colorama import Fore, Style
@@ -88,7 +88,7 @@ def vcs_state_diverges_from_master() -> bool:
def get_git_user_email() -> str:
    try:
        repo = Repo(search_parent_directories=True)
-        return repo.config_reader().get_value("user", "email", default="")
+        return cast(str, repo.config_reader().get_value("user", "email", default=""))
    except InvalidGitRepositoryError:
        return ""
autogpt/poetry.lock (529 changes, generated)
File diff suppressed because one or more lines are too long
@@ -1,9 +1,7 @@
[tool.poetry]
name = "agpt"
version = "0.5.0"
-authors = [
-    "Significant Gravitas <support@agpt.co>",
-]
+authors = ["Significant Gravitas <support@agpt.co>"]
readme = "README.md"
description = "An open-source attempt to make GPT-4 autonomous"
homepage = "https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt"
@@ -30,11 +28,10 @@ charset-normalizer = "^3.1.0"
click = "*"
colorama = "^0.4.6"
distro = "^1.8.0"
-en-core-web-sm = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl"}
+en-core-web-sm = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl" }
fastapi = "^0.109.1"
ftfy = "^6.1.1"
google-api-python-client = "*"
-gTTS = "^2.3.1"
hypercorn = "^0.14.4"
inflection = "*"
jsonschema = "*"
@@ -58,21 +55,18 @@ openapi-python-client = "^0.14.0"
# Benchmarking
agbenchmark = { path = "../benchmark", optional = true }
# agbenchmark = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "benchmark", optional = true}
google-cloud-logging = "^3.8.0"
google-cloud-storage = "^2.13.0"
psycopg2-binary = "^2.9.9"

[tool.poetry.extras]
benchmark = ["agbenchmark"]

[tool.poetry.group.dev.dependencies]
-black = "*"
-boto3-stubs = {extras = ["s3"], version = "^1.33.6"}
-flake8 = "*"
+black = "^23.12.1"
+flake8 = "^7.0.0"
gitpython = "^3.1.32"
-isort = "*"
-mypy = "*"
+isort = "^5.13.1"
pre-commit = "*"
+pyright = "^1.1.364"
types-beautifulsoup4 = "*"
types-colorama = "*"
types-Markdown = "*"
@@ -89,7 +83,7 @@ pytest-integration = "*"
pytest-mock = "*"
pytest-recording = "*"
pytest-xdist = "*"
-vcrpy = {git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master"}
+vcrpy = { git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master" }

@@ -101,50 +95,18 @@ build-backend = "poetry.core.masonry.api"
line-length = 88
target-version = ['py310']
include = '\.pyi?$'
packages = ["autogpt"]
extend-exclude = '.+/(dist|.venv|venv|build|data)/.+'

[tool.isort]
profile = "black"
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
-ensure_newline_before_comments = true
-line_length = 88
-sections = [
-    "FUTURE",
-    "STDLIB",
-    "THIRDPARTY",
-    "FIRSTPARTY",
-    "LOCALFOLDER"
-]
-extend_skip = [
-    "agbenchmark_config/temp_folder/",
-    "data/",
-]
+skip_glob = ["data"]

-[tool.mypy]
-follow_imports = 'skip'
-check_untyped_defs = true
-disallow_untyped_calls = true
-files = [
-    'autogpt/**/*.py',
-    'tests/**/*.py'
-]
-
-[[tool.mypy.overrides]]
-module = [
-    'requests.*',
-    'yaml.*'
-]
-ignore_missing_imports = true
+[tool.pyright]
+pythonVersion = "3.10"
+exclude = ["data/**", "**/node_modules", "**/__pycache__", "**/.*"]
+ignore = ["../forge/**"]

[tool.pytest.ini_options]
-markers = [
-    "requires_openai_api_key",
-    "requires_huggingface_api_key"
-]
+markers = ["slow", "requires_openai_api_key", "requires_huggingface_api_key"]
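The newly registered `slow` marker above is what lets the pre-commit hooks run `pytest -m "not slow"` while full CI still runs everything. Marking a whole test module slow looks roughly like this (a minimal sketch, not a file from the repo):

```python
import pytest

# Applies the marker to every test in this module; pre-commit then skips
# them via `pytest -m "not slow"`, while the CI workflows run them all.
pytestmark = pytest.mark.slow


def test_expensive_cloud_roundtrip():
    ...  # placeholder for a test that talks to a real backend
```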
@@ -4,12 +4,12 @@ import sys
from importlib.metadata import version

try:
-    import poetry.factory  # noqa
+    import poetry.factory  # type: ignore # noqa
except ModuleNotFoundError:
    os.system(f"{sys.executable} -m pip install 'poetry>=1.6.1,<2.0.0'")

-from poetry.core.constraints.version.version import Version
-from poetry.factory import Factory
+from poetry.core.constraints.version.version import Version  # type: ignore
+from poetry.factory import Factory  # type: ignore


def main():
@@ -20,7 +20,7 @@ from autogpt.app.utils import coroutine
)
@coroutine
async def generate_release_notes(repo_path: Optional[Path] = None):
-    logger = logging.getLogger(generate_release_notes.name)
+    logger = logging.getLogger(generate_release_notes.name)  # pyright: ignore

    repo = Repo(repo_path, search_parent_directories=True)
    tags = list(repo.tags)
@@ -12,7 +12,7 @@ from forge.file_storage.local import (
    FileStorageConfiguration,
    LocalFileStorage,
)
-from forge.llm.providers import ChatModelProvider
+from forge.llm.providers import MultiProvider
from forge.logging.config import configure_logging

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@@ -71,14 +71,12 @@ def setup_logger(config: Config):


@pytest.fixture
-def llm_provider(config: Config) -> ChatModelProvider:
+def llm_provider(config: Config) -> MultiProvider:
    return _configure_llm_provider(config)


@pytest.fixture
-def agent(
-    config: Config, llm_provider: ChatModelProvider, storage: FileStorage
-) -> Agent:
+def agent(config: Config, llm_provider: MultiProvider, storage: FileStorage) -> Agent:
    ai_profile = AIProfile(
        ai_name="Base",
        ai_role="A base AI",
|
@ -1,13 +1,16 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from forge.config.ai_profile import AIProfile
|
||||
from forge.config.config import Config
|
||||
from forge.file_storage import FileStorageBackendName, get_storage
|
||||
from forge.llm.providers import MultiProvider
|
||||
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_agent(config: Config, llm_provider, memory_json_file):
|
||||
def dummy_agent(config: Config, llm_provider: MultiProvider):
|
||||
ai_profile = AIProfile(
|
||||
ai_name="Dummy Agent",
|
||||
ai_role="Dummy Role",
|
||||
@ -31,7 +34,9 @@ def dummy_agent(config: Config, llm_provider, memory_json_file):
|
||||
local = config.file_storage_backend == FileStorageBackendName.LOCAL
|
||||
restrict_to_root = not local or config.restrict_to_workspace
|
||||
file_storage = get_storage(
|
||||
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
|
||||
config.file_storage_backend,
|
||||
root_path=Path("data"),
|
||||
restrict_to_root=restrict_to_root,
|
||||
)
|
||||
file_storage.initialize()
|
||||
|
||||
|
@@ -4,7 +4,7 @@ import tempfile
from pathlib import Path

import pytest
-from forge.components.code_executor import (
+from forge.components.code_executor.code_executor import (
    ALLOWLIST_CONTROL,
    CodeExecutorComponent,
    is_docker_available,
@@ -257,17 +257,3 @@ def test_huggingface_fail_request_bad_image(
    result = image_gen_component.generate_image("astronaut riding a horse", 512)

    assert result == "Error creating image."
-
-
-def test_huggingface_fail_missing_api_token(
-    mocker, image_gen_component: ImageGeneratorComponent, agent: Agent
-):
-    agent.legacy_config.image_provider = "huggingface"
-    agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
-
-    # Mock requests.post to raise ValueError
-    mocker.patch("requests.post", side_effect=ValueError)
-
-    # Verify request raises an error.
-    with pytest.raises(ValueError):
-        image_gen_component.generate_image("astronaut riding a horse", 512)
@@ -67,8 +67,8 @@ def test_missing_azure_config(config: Config) -> None:
    with pytest.raises(ValueError):
        config.openai_credentials.load_azure_config(config_file)

-    assert config.openai_credentials.api_type != "azure"
-    assert config.openai_credentials.api_version == ""
+    assert config.openai_credentials.api_type != SecretStr("azure")
+    assert config.openai_credentials.api_version is None
    assert config.openai_credentials.azure_model_to_deploy_id_map is None


@@ -98,8 +98,8 @@ azure_model_map:

def test_azure_config(config_with_azure: Config) -> None:
    assert (credentials := config_with_azure.openai_credentials) is not None
-    assert credentials.api_type == "azure"
-    assert credentials.api_version == "2023-06-01-preview"
+    assert credentials.api_type == SecretStr("azure")
+    assert credentials.api_version == SecretStr("2023-06-01-preview")
    assert credentials.azure_endpoint == SecretStr("https://dummy.openai.azure.com")
    assert credentials.azure_model_to_deploy_id_map == {
        config_with_azure.fast_llm: "FAST-LLM_ID",
@@ -4,7 +4,7 @@ from pathlib import Path

import pytest
import pytest_asyncio
-from forge.file_storage import GCSFileStorage, GCSFileStorageConfiguration
+from forge.file_storage.gcs import GCSFileStorage, GCSFileStorageConfiguration
from google.auth.exceptions import GoogleAuthError
from google.cloud import storage
from google.cloud.exceptions import NotFound
@@ -14,6 +14,8 @@ try:
except GoogleAuthError:
    pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)

+pytestmark = pytest.mark.slow
+

@pytest.fixture(scope="module")
def gcs_bucket_name() -> str:
@@ -26,7 +28,7 @@ def gcs_root() -> Path:


@pytest.fixture(scope="module")
-def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path) -> GCSFileStorage:
+def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path):
    os.environ["STORAGE_BUCKET"] = gcs_bucket_name
    storage_config = GCSFileStorageConfiguration.from_env()
    storage_config.root = gcs_root
@@ -52,7 +54,7 @@ def test_initialize(gcs_bucket_name: str, gcs_storage_uninitialized: GCSFileStor


@pytest.fixture(scope="module")
-def gcs_storage(gcs_storage_uninitialized: GCSFileStorage) -> GCSFileStorage:
+def gcs_storage(gcs_storage_uninitialized: GCSFileStorage):
    (gcs_storage := gcs_storage_uninitialized).initialize()
    yield gcs_storage  # type: ignore

@@ -77,7 +79,7 @@ TEST_FILES: list[tuple[str | Path, str]] = [


@pytest_asyncio.fixture
-async def gcs_storage_with_files(gcs_storage: GCSFileStorage) -> GCSFileStorage:
+async def gcs_storage_with_files(gcs_storage: GCSFileStorage):
    for file_name, file_content in TEST_FILES:
        gcs_storage._bucket.blob(
            str(gcs_storage.get_path(file_name))
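The expanded `GCSFileStorage.open_file` overloads mentioned in the commit message are not shown in this excerpt, but they follow the standard `typing.overload` pattern for mode-dependent return types. A minimal sketch under assumed signatures (the real method takes more parameters and wraps a GCS blob stream):

```python
from typing import BinaryIO, Literal, TextIO, overload


class GCSFileStorageSketch:  # hypothetical stand-in for GCSFileStorage
    @overload
    def open_file(self, path: str, mode: Literal["r", "w"] = "r") -> TextIO: ...

    @overload
    def open_file(self, path: str, mode: Literal["rb", "wb"]) -> BinaryIO: ...

    def open_file(self, path: str, mode: str = "r"):
        # A real implementation dispatches on mode and returns the blob's
        # text or binary stream; this sketch only shows the typing pattern.
        raise NotImplementedError
```

With this, `open_file(p)` type-checks as `TextIO` while `open_file(p, "rb")` type-checks as `BinaryIO`, without `IOBase` in the return union.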
@@ -1,7 +1,7 @@
import json

import pytest
-from forge.json import json_loads
+from forge.json.parsing import json_loads

_JSON_FIXABLE: list[tuple[str, str]] = [
    # Missing comma
@@ -1,7 +1,7 @@
from pathlib import Path

import pytest
-from forge.file_storage import FileStorageConfiguration, LocalFileStorage
+from forge.file_storage.local import FileStorageConfiguration, LocalFileStorage

_ACCESSIBLE_PATHS = [
    Path("."),
@@ -5,7 +5,7 @@ from pathlib import Path
import pytest
import pytest_asyncio
from botocore.exceptions import ClientError
-from forge.file_storage import S3FileStorage, S3FileStorageConfiguration
+from forge.file_storage.s3 import S3FileStorage, S3FileStorageConfiguration

if not (os.getenv("S3_ENDPOINT_URL") and os.getenv("AWS_ACCESS_KEY_ID")):
    pytest.skip("S3 environment variables are not set", allow_module_level=True)
@@ -22,7 +22,7 @@ def s3_root() -> Path:


@pytest.fixture
-def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path) -> S3FileStorage:
+def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path):
    os.environ["STORAGE_BUCKET"] = s3_bucket_name
    storage_config = S3FileStorageConfiguration.from_env()
    storage_config.root = s3_root
@@ -36,12 +36,13 @@ def test_initialize(s3_bucket_name: str, s3_storage_uninitialized: S3FileStorage

    # test that the bucket doesn't exist yet
    with pytest.raises(ClientError):
-        s3.meta.client.head_bucket(Bucket=s3_bucket_name)
+        s3.meta.client.head_bucket(Bucket=s3_bucket_name)  # pyright: ignore

    s3_storage_uninitialized.initialize()

    # test that the bucket has been created
-    s3.meta.client.head_bucket(Bucket=s3_bucket_name)
+    s3.meta.client.head_bucket(Bucket=s3_bucket_name)  # pyright: ignore
+    # FIXME: remove the "pyright: ignore" comments after moving this test file to forge


def test_workspace_bucket_name(
@@ -52,7 +53,7 @@ def test_workspace_bucket_name(


@pytest.fixture
-def s3_storage(s3_storage_uninitialized: S3FileStorage) -> S3FileStorage:
+def s3_storage(s3_storage_uninitialized: S3FileStorage):
    (s3_storage := s3_storage_uninitialized).initialize()
    yield s3_storage  # type: ignore

@@ -71,7 +72,7 @@ TEST_FILES: list[tuple[str | Path, str]] = [


@pytest_asyncio.fixture
-async def s3_storage_with_files(s3_storage: S3FileStorage) -> S3FileStorage:
+async def s3_storage_with_files(s3_storage: S3FileStorage):
    for file_name, file_content in TEST_FILES:
        s3_storage._bucket.Object(str(s3_storage.get_path(file_name))).put(
            Body=file_content
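The `S3BinaryIOWrapper(BinaryIO)` interposer from the commit message is also not part of this excerpt. Conceptually (class name aside, all internals here are assumed), it adapts a read-only network stream such as botocore's `StreamingBody` to Python's binary file protocol so `open_file` can return a properly typed `BinaryIO`:

```python
import io
from typing import Any, BinaryIO, cast


class _StreamingBodyWrapper(io.RawIOBase):
    """Adapts a read-only stream (e.g. a botocore StreamingBody) to RawIOBase."""

    def __init__(self, body: Any) -> None:
        self._body = body

    def readable(self) -> bool:
        return True

    def readinto(self, b) -> int:
        # Pull at most len(b) bytes from the underlying stream into the buffer.
        data = self._body.read(len(b))
        b[: len(data)] = data
        return len(data)


def open_s3_object_as_binary(body: Any) -> BinaryIO:
    # BufferedReader adds readline(), iteration, and context-manager support
    # on top of the raw wrapper; cast() asserts the BinaryIO protocol.
    return cast(BinaryIO, io.BufferedReader(_StreamingBodyWrapper(body)))
```

Usage sketch: `f = open_s3_object_as_binary(s3.Object(bucket, key).get()["Body"])`, after which `f.read()` behaves like any binary file handle.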
@@ -1,6 +1,7 @@
import logging
import os
from hashlib import sha256
+from typing import cast

import pytest
from openai import OpenAI
@@ -53,11 +54,14 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
    def _patched_prepare_options(self, options: FinalRequestOptions):
        _prepare_options(options)

+        if not options.json_data:
+            return
+
        headers: dict[str, str | Omit] = (
            {**options.headers} if is_given(options.headers) else {}
        )
        options.headers = headers
-        data: dict = options.json_data
+        data = cast(dict, options.json_data)

        logging.getLogger("cached_openai_client").debug(
            f"Outgoing API request: {headers}\n{data if data else None}"
@@ -1,15 +1,12 @@
 [flake8]
 max-line-length = 88
-select = "E303, W293, W291, W292, E305, E231, E302"
+# Ignore rules that conflict with Black code style
 extend-ignore = E203, W503
 exclude =
     .tox,
-    __pycache__,
+    __pycache__/,
     *.pyc,
     .env
-    venv*/*,
-    .venv/*,
-    reports/*,
-    dist/*,
-    agent/*,
-    code,
-    agbenchmark/challenges/*
+    .pytest_cache/,
+    venv*/,
+    .venv/,
+    reports/,
+    agbenchmark/reports/,
@@ -1,36 +0,0 @@
-repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
-    hooks:
-      - id: check-added-large-files
-        args: ['--maxkb=500']
-      - id: check-byte-order-marker
-      - id: check-case-conflict
-      - id: check-merge-conflict
-      - id: check-symlinks
-      - id: debug-statements
-
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        language_version: python3.10
-
-  - repo: https://github.com/psf/black
-    rev: 23.3.0
-    hooks:
-      - id: black
-        language_version: python3.10
-
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v1.3.0'
-    hooks:
-      - id: mypy
-
-  - repo: local
-    hooks:
-      - id: autoflake
-        name: autoflake
-        entry: autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring --in-place agbenchmark
-        language: python
-        types: [ python ]
@@ -28,7 +28,7 @@ async def run_api_agent(
     configuration = Configuration(host=config.host)
     async with ApiClient(configuration) as api_client:
         api_instance = AgentApi(api_client)
-        task_request_body = TaskRequestBody(input=task)
+        task_request_body = TaskRequestBody(input=task, additional_input=None)

         start_time = time.time()
         response = await api_instance.create_agent_task(
@@ -106,8 +106,8 @@ def find_agbenchmark_without_uvicorn():


 class CreateReportRequest(BaseModel):
-    test: str = None
-    test_run_id: str = None
+    test: str
+    test_run_id: str
     # category:  Optional[str] = []
     mock: Optional[bool] = False
@@ -178,8 +178,8 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
         logger.debug(f"Benchmark finished running in {time.time() - start_time} s")

         # List all folders in the current working directory
-        path_reports = agbenchmark_config.reports_folder
-        folders = [folder for folder in path_reports.iterdir() if folder.is_dir()]
+        reports_folder = agbenchmark_config.reports_folder
+        folders = [folder for folder in reports_folder.iterdir() if folder.is_dir()]

         # Sort the folders based on their names
         sorted_folders = sorted(folders, key=lambda x: x.name)
@@ -196,13 +196,14 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
                     data = json.load(file)
                     logger.debug(f"Report data: {data}")
             else:
-                logger.error(
+                raise HTTPException(
+                    502,
                     "Could not get result after running benchmark: "
-                    f"'report.json' does not exist in '{latest_folder}'"
+                    f"'report.json' does not exist in '{latest_folder}'",
                 )
         else:
-            logger.error(
-                "Could not get result after running benchmark: no reports found"
+            raise HTTPException(
+                504, "Could not get result after running benchmark: no reports found"
             )

         return data
@@ -239,7 +240,9 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
             api_instance = AgentApi(api_client)
             task_input = challenge_info.task

-            task_request_body = TaskRequestBody(input=task_input)
+            task_request_body = TaskRequestBody(
+                input=task_input, additional_input=None
+            )
             task_response = await api_instance.create_agent_task(
                 task_request_body=task_request_body
             )
@@ -276,7 +279,7 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
             # Forward the request
             response = await client.post(
                 new_url,
-                data=await request.body(),
+                content=await request.body(),
                 headers=dict(request.headers),
             )
@@ -1,7 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import AsyncIterator, ClassVar, Optional
+from typing import AsyncIterator, Awaitable, ClassVar, Optional

 import pytest
 from agent_protocol_client import AgentApi, Step
@@ -54,7 +54,7 @@ class BaseChallenge(ABC):
         config: AgentBenchmarkConfig,
         request: pytest.FixtureRequest,
         i_attempt: int,
-    ) -> None:
+    ) -> None | Awaitable[None]:
         """
         Test method for use by Pytest-based benchmark sessions. Should return normally
         if the challenge passes, and raise a (preferably descriptive) error otherwise.
@@ -1,4 +1,3 @@
-from collections import deque
 import glob
 import json
 import logging
@@ -6,19 +5,17 @@ import os
 import subprocess
 import sys
 import tempfile
+from collections import deque
 from pathlib import Path
-from typing import Any, ClassVar, Iterator, Literal, Optional
+from typing import Annotated, Any, ClassVar, Iterator, Literal, Optional

 import pytest
-from agent_protocol_client import (
-    AgentApi,
-    ApiClient,
-    Configuration as ClientConfig,
-    Step,
-)
+from agent_protocol_client import AgentApi, ApiClient
+from agent_protocol_client import Configuration as ClientConfig
+from agent_protocol_client import Step
 from colorama import Fore, Style
 from openai import _load_client as get_openai_client
-from pydantic import BaseModel, constr, Field, validator
+from pydantic import BaseModel, Field, constr, validator

 from agbenchmark.agent_api_interface import download_agent_artifacts_into_folder
 from agbenchmark.agent_interface import copy_challenge_artifacts_into_workspace
@@ -49,7 +46,7 @@ class BuiltinChallengeSpec(BaseModel):

     class Info(BaseModel):
         difficulty: DifficultyLevel
-        description: constr(regex=r"^Tests if the agent can.*")
+        description: Annotated[str, constr(regex=r"^Tests if the agent can.*")]
         side_effects: list[str] = Field(default_factory=list)

     info: Info
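Bare `constr(...)` annotations work at runtime under Pydantic v1 but are not valid types to a static checker; wrapping them in `typing.Annotated` keeps the declared type as `str` while preserving validation. A small sketch of the pattern (the model name is illustrative):

```python
from typing import Annotated

from pydantic import BaseModel, constr


class ChallengeInfo(BaseModel):
    # Pyright sees a plain `str`; Pydantic still enforces the regex constraint.
    description: Annotated[str, constr(regex=r"^Tests if the agent can.*")]


info = ChallengeInfo(description="Tests if the agent can write a file")
```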
@@ -184,7 +181,7 @@ class BuiltinChallenge(BaseChallenge):
         steps: list[Step] = []
         try:
             async for step in self.run_challenge(
-                config, timeout, mock=request.config.getoption("--mock")
+                config, timeout, mock=bool(request.config.getoption("--mock"))
             ):
                 if not task_id:
                     task_id = step.task_id
@@ -199,6 +196,8 @@ class BuiltinChallenge(BaseChallenge):
             timed_out = False
         except TimeoutError:
             timed_out = True

+        assert isinstance(request.node, pytest.Item)
         request.node.user_properties.append(("steps", steps))
         request.node.user_properties.append(("n_steps", n_steps))
         request.node.user_properties.append(("timed_out", timed_out))
@@ -411,15 +410,10 @@ class BuiltinChallenge(BaseChallenge):
 def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
     logger.info("Loading built-in challenges...")

-    challenges_path = os.path.dirname(__file__)
+    challenges_path = Path(__file__).parent
     logger.debug(f"Looking for challenge spec files in {challenges_path}...")

-    json_files = deque(
-        glob.glob(
-            f"{challenges_path}/**/data.json",
-            recursive=True,
-        )
-    )
+    json_files = deque(challenges_path.rglob("data.json"))

     logger.debug(f"Found {len(json_files)} built-in challenges.")

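Switching from `glob.glob` to `Path.rglob` is what makes the challenge-exclusion checks below work on Windows: `glob.glob` returns strings with OS-native separators, while `Path.as_posix()` gives a separator-stable form to match against. A short sketch of the difference (the directory layout is hypothetical):

```python
from pathlib import Path

challenges_path = Path("agbenchmark/challenges")  # hypothetical location

# On Windows, glob.glob() yields strings like
# "agbenchmark\\challenges\\deprecated\\data.json", so a check such as
# `"challenges/deprecated" in path` silently never matches.
# Path.rglob() yields Path objects; as_posix() normalizes the separators:
for json_file in challenges_path.rglob("data.json"):
    if "challenges/deprecated" in json_file.as_posix():
        continue  # excluded on every platform
```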
@@ -431,7 +425,7 @@ def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
             ignored += 1
             continue

-        challenge = BuiltinChallenge.from_challenge_spec_file(Path(json_file))
+        challenge = BuiltinChallenge.from_challenge_spec_file(json_file)
         logger.debug(f"Generated test for {challenge.info.name}")
         yield challenge

@@ -442,8 +436,8 @@ def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
     )


-def _challenge_should_be_ignored(json_file_path: str):
+def _challenge_should_be_ignored(json_file_path: Path):
     return (
-        "challenges/deprecated" in json_file_path
-        or "challenges/library" in json_file_path
+        "challenges/deprecated" in json_file_path.as_posix()
+        or "challenges/library" in json_file_path.as_posix()
     )
@@ -23,9 +23,10 @@ def test_get_ethereum_price() -> None:
     real_eth_price_value = float(real_eth_price)

     # Check if the eth price is within $50 of the actual Ethereum price
-    assert (
-        abs(real_eth_price_value - eth_price_value) <= 50
-    ), f"AssertionError: Ethereum price is not within $50 of the actual Ethereum price (Provided price: ${eth_price}, Real price: ${real_eth_price})"
+    assert abs(real_eth_price_value - eth_price_value) <= 50, (
+        "AssertionError: Ethereum price is not within $50 of the actual Ethereum price "
+        f"(Provided price: ${eth_price}, Real price: ${real_eth_price})"
+    )

     print("Matches")

@@ -23,9 +23,10 @@ def test_get_ethereum_price() -> None:
     real_eth_price_value = float(real_eth_price)

     # Check if the eth price is within $50 of the actual Ethereum price
-    assert (
-        abs(real_eth_price_value - eth_price_value) <= 50
-    ), f"AssertionError: Ethereum price is not within $50 of the actual Ethereum price (Provided price: ${eth_price}, Real price: ${real_eth_price})"
+    assert abs(real_eth_price_value - eth_price_value) <= 50, (
+        "AssertionError: Ethereum price is not within $50 of the actual Ethereum price "
+        f"(Provided price: ${eth_price}, Real price: ${real_eth_price})"
+    )

     print("Matches")
@@ -1,4 +1,3 @@
-# mypy: ignore-errors
 from typing import List, Optional

@@ -1,4 +1,4 @@
-# mypy: ignore-errors
+# pyright: reportMissingImports=false
 from typing import List

 from sample_code import three_sum
@@ -21,7 +21,6 @@ def generate_password(length: int = 8) -> str:

 if __name__ == "__main__":
     password_length = (
-        int(sys.argv[sys.argv.index("--length") + 1])
-        if "--length" in sys.argv else 8
+        int(sys.argv[sys.argv.index("--length") + 1]) if "--length" in sys.argv else 8
     )
     print(generate_password(password_length))
@@ -1,3 +1,4 @@
+# pyright: reportMissingImports=false
 import unittest

 import password_generator
@@ -18,7 +19,9 @@ class TestPasswordGenerator(unittest.TestCase):
     def test_password_content(self):
         password = password_generator.generate_password()
         self.assertTrue(any(c.isdigit() for c in password))
-        self.assertTrue(any(c in password_generator.string.punctuation for c in password))
+        self.assertTrue(
+            any(c in password_generator.string.punctuation for c in password)
+        )


 if __name__ == "__main__":
@@ -1,3 +1,4 @@
+# pyright: reportMissingImports=false
 import unittest

 from url_shortener import retrieve_url, shorten_url
@@ -56,7 +56,7 @@ def winner(board):

 def getLocation():
     location = input(
-        "Choose where to play. Enter two numbers separated by a comma, for example: 1,1 "
+        "Choose where to play. Enter two numbers separated by a comma [example: 1,1]: "
     )
     print(f"\nYou picked {location}")
     coordinates = [int(x) for x in location.split(",")]
@@ -69,7 +69,8 @@ def getLocation():
     ):
         print("You inputted a location in an invalid format")
         location = input(
-            "Choose where to play. Enter two numbers separated by a comma, for example: 1,1 "
+            "Choose where to play. Enter two numbers separated by a comma "
+            "[example: 1,1]: "
         )
     coordinates = [int(x) for x in location.split(",")]
     return coordinates
@@ -37,15 +37,14 @@ class GameStatus(BaseModel):
     winner: Optional[str]


-from typing import List
-
-
 class Game(BaseModel):
     game_id: str
-    players: List[str]
-    board: dict  # This could represent the state of the game board, you might need to flesh this out further
-    ships: List[ShipPlacement]  # List of ship placements for this game
-    turns: List[Turn]  # List of turns that have been taken
+    players: list[str]
+    # This could represent the state of the game board,
+    # you might need to flesh this out further:
+    board: dict
+    ships: list[ShipPlacement]  # List of ship placements for this game
+    turns: list[Turn]  # List of turns that have been taken


 class AbstractBattleship(ABC):
@@ -86,7 +85,7 @@ class AbstractBattleship(ABC):
         pass

     @abstractmethod
-    def get_game(self) -> Game:
+    def get_game(self) -> Game | None:
         """
         Retrieve the state of the game.
         """
@@ -103,5 +102,8 @@ class AbstractBattleship(ABC):
     def create_game(self) -> None:
         """
         Create a new game.
+
+        Returns:
+            str: The ID of the created game.
         """
         pass
@@ -1,3 +1,4 @@
+# pyright: reportMissingImports=false
 import pytest
 from abstract_class import ShipPlacement, Turn
 from battleship import Battleship
@@ -50,7 +50,7 @@ def test_cant_hit_before_ships_placed(battleship_game):


 def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id):
-    game = battleship_game.get_game(initialized_game_id)
+    battleship_game.get_game(initialized_game_id)
     additional_ship = ShipPlacement(
         ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal"
     )
@@ -61,6 +61,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
         {"row": 1, "column": "H"},
     ]

+    response = None
     for index, hit in enumerate(hits):
         turn = Turn(target={"row": 2, "column": hit})
         response = battleship_game.create_turn(initialized_game_id, turn)
@@ -69,7 +70,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
         static_turn = Turn(target=static_moves[index])
         battleship_game.create_turn(initialized_game_id, static_turn)

-    assert response.result == "sunk"
+    assert response and response.result == "sunk"


 def test_restart_game(battleship_game):
|
@ -37,15 +37,14 @@ class GameStatus(BaseModel):
|
||||
winner: Optional[str]
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
class Game(BaseModel):
|
||||
game_id: str
|
||||
players: List[str]
|
||||
board: dict # This could represent the state of the game board, you might need to flesh this out further
|
||||
ships: List[ShipPlacement] # List of ship placements for this game
|
||||
turns: List[Turn] # List of turns that have been taken
|
||||
players: list[str]
|
||||
# This could represent the state of the game board,
|
||||
# you might need to flesh this out further:
|
||||
board: dict
|
||||
ships: list[ShipPlacement] # List of ship placements for this game
|
||||
turns: list[Turn] # List of turns that have been taken
|
||||
|
||||
|
||||
class AbstractBattleship(ABC):
|
||||
@ -86,7 +85,7 @@ class AbstractBattleship(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_game(self) -> Game:
|
||||
def get_game(self, game_id: str) -> Game | None:
|
||||
"""
|
||||
Retrieve the state of the game.
|
||||
"""
|
||||
@ -100,8 +99,11 @@ class AbstractBattleship(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_game(self) -> None:
|
||||
def create_game(self) -> str:
|
||||
"""
|
||||
Create a new game.
|
||||
|
||||
Returns:
|
||||
str: The ID of the created game.
|
||||
"""
|
||||
pass
|
||||
|
@@ -1,14 +1,20 @@
 from typing import Dict

-from abstract_class import (AbstractBattleship, Game, GameStatus,
-                            ShipPlacement, Turn, TurnResponse)
+from abstract_class import (
+    AbstractBattleship,
+    Game,
+    GameStatus,
+    ShipPlacement,
+    Turn,
+    TurnResponse,
+)


 class Battleship(AbstractBattleship):
     def __init__(self):
-        self.games: Dict[int, Game] = {}
+        self.games: Dict[str, Game] = {}

-    def create_game(self) -> int:
+    def create_game(self) -> str:
         game_id = str(len(self.games))
         new_game = Game(
             game_id=game_id,
@@ -19,7 +25,7 @@ class Battleship(AbstractBattleship):
         )

         self.games[game_id] = new_game
-        return new_game.game_id
+        return game_id

     def create_ship_placement(self, game_id: str, placement: ShipPlacement) -> None:
         game = self.games.get(game_id)
@@ -79,38 +85,34 @@ class Battleship(AbstractBattleship):

         game.turns.append(turn)

-        if hit_ship == "hit":
+        if not hit_ship or hit_ship == "hit":  # if no ship or already hit
             return TurnResponse(result="miss", ship_type=None)

-        if hit_ship:
-            ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
-            start_row, start_col = ship_placement.start["row"], ord(
-                ship_placement.start["column"]
-            ) - ord("A")
-            ship_positions = [
-                (
-                    start_row + (i if ship_placement.direction == "vertical" else 0),
-                    start_col + (i if ship_placement.direction == "horizontal" else 0),
-                )
-                for i in range(self.SHIP_LENGTHS[hit_ship])
-            ]
+        ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
+        start_row, start_col = (
+            ship_placement.start["row"],
+            ord(ship_placement.start["column"]) - ord("A"),
+        )
+        ship_positions = [
+            (
+                start_row + (i if ship_placement.direction == "vertical" else 0),
+                start_col + (i if ship_placement.direction == "horizontal" else 0),
+            )
+            for i in range(self.SHIP_LENGTHS[hit_ship])
+        ]

-            targeted_positions = {
-                (t.target["row"], ord(t.target["column"]) - ord("A"))
-                for t in game.turns
-            }
+        targeted_positions = {
+            (t.target["row"], ord(t.target["column"]) - ord("A")) for t in game.turns
+        }

-            game.board[(target_row, target_col)] = "hit"
+        game.board[(target_row, target_col)] = "hit"

-            if set(ship_positions).issubset(targeted_positions):
-                for pos in ship_positions:
-                    game.board[pos] = "hit"
-                return TurnResponse(result="sunk", ship_type=hit_ship)
-            else:
-                return TurnResponse(result="hit", ship_type=hit_ship)
+        if set(ship_positions).issubset(targeted_positions):
+            for pos in ship_positions:
+                game.board[pos] = "hit"
+            return TurnResponse(result="sunk", ship_type=hit_ship)
+        else:
+            return TurnResponse(result="hit", ship_type=hit_ship)

     def get_game_status(self, game_id: str) -> GameStatus:
         game = self.games.get(game_id)
@@ -132,12 +134,12 @@ class Battleship(AbstractBattleship):
     def get_winner(self, game_id: str) -> str:
         game_status = self.get_game_status(game_id)

-        if game_status.is_game_over:
+        if game_status.is_game_over and game_status.winner:
             return game_status.winner
-        else:
-            return None
+        raise ValueError(f"Game {game_id} isn't over yet")

-    def get_game(self, game_id: str) -> Game:
+    def get_game(self, game_id: str) -> Game | None:
         return self.games.get(game_id)

     def delete_game(self, game_id: str) -> None:
@@ -50,7 +50,7 @@ def test_cant_hit_before_ships_placed(battleship_game):


 def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id):
-    game = battleship_game.get_game(initialized_game_id)
+    battleship_game.get_game(initialized_game_id)
     additional_ship = ShipPlacement(
         ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal"
     )
@@ -61,6 +61,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
         {"row": 1, "column": "H"},
     ]

+    response = None
     for index, hit in enumerate(hits):
         turn = Turn(target={"row": 2, "column": hit})
         response = battleship_game.create_turn(initialized_game_id, turn)
@@ -69,7 +70,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
         static_turn = Turn(target=static_moves[index])
         battleship_game.create_turn(initialized_game_id, static_turn)

-    assert response.result == "sunk"
+    assert response and response.result == "sunk"


 def test_restart_game(battleship_game):
@@ -6,7 +6,7 @@ from typing import ClassVar, Iterator, Literal
 import pytest
 import requests
 from agent_protocol_client import AgentApi, Step
-from pydantic import BaseModel, validator, ValidationError
+from pydantic import BaseModel, ValidationError, validator

 from agbenchmark.config import AgentBenchmarkConfig
 from agbenchmark.utils.data_types import Category, EvalResult
@@ -93,11 +93,12 @@ class Eval(ABC):
         ...


-class StringEval(BaseModel, Eval):
-    type: ReferenceAnswerType
+class BaseStringEval(BaseModel, Eval):
+    # type: ReferenceAnswerType
+    pass


-class ExactStringMatchEval(StringEval):
+class ExactStringMatchEval(BaseStringEval):
     type: Literal["exact_match"] = "exact_match"
     reference_answer: str

@@ -109,7 +110,7 @@ class ExactStringMatchEval(StringEval):
         return string == self.reference_answer


-class FuzzyStringMatchEval(StringEval):
+class FuzzyStringMatchEval(BaseStringEval):
     type: Literal["fuzzy_match"] = "fuzzy_match"
     reference_answer: str

@@ -122,7 +123,7 @@ class FuzzyStringMatchEval(StringEval):
     return self.reference_answer.lower() in string.lower()


-class MustIncludeStringEval(StringEval):
+class MustIncludeStringEval(BaseStringEval):
     type: Literal["must_include"] = "must_include"
     reference_answer: str

@@ -134,6 +135,9 @@ class MustIncludeStringEval(StringEval):
         return self.reference_answer.lower() in string.lower()


+StringEval = ExactStringMatchEval | FuzzyStringMatchEval | MustIncludeStringEval
+
+
 class UrlMatchEval(BaseModel, Eval):
     url: str
     """Example: `"__WIKI__/wiki/Octopus"`"""
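Turning `StringEval` into a union of the concrete eval models means Pydantic can select the right class from the `Literal` `type` field when parsing challenge specs, instead of instantiating an abstract base. A rough sketch of how such a union parses under Pydantic v1 (the `parse_obj_as` call is illustrative, classes abridged from the diff):

```python
from typing import Literal, Union

from pydantic import BaseModel, parse_obj_as


class ExactStringMatchEval(BaseModel):
    type: Literal["exact_match"] = "exact_match"
    reference_answer: str


class MustIncludeStringEval(BaseModel):
    type: Literal["must_include"] = "must_include"
    reference_answer: str


StringEval = Union[ExactStringMatchEval, MustIncludeStringEval]

# Pydantic v1 tries the union members in order; the Literal `type` fields make
# the match unambiguous, so this yields a MustIncludeStringEval instance.
evaluation = parse_obj_as(
    StringEval, {"type": "must_include", "reference_answer": "42"}
)
```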
@@ -142,8 +146,8 @@ class UrlMatchEval(BaseModel, Eval):
     def description(self) -> str:
         return f"Agent must navigate to '{self.url}'"

-    def evaluate(self, url: str) -> bool:
-        return url == resolve_uri(self.url)
+    def evaluate(self, string: str) -> bool:
+        return string == resolve_uri(self.url)


 class ProgramHtmlEval(BaseModel):
@@ -258,7 +262,8 @@ class WebArenaChallengeSpec(BaseModel):
                 f"{' and '.join(s.base_url for s in sites)}.\n\n"
                 + "\n".join(
                     s.additional_info.format(url=s.base_url)
-                    for s in sites if s.additional_info
+                    for s in sites
+                    if s.additional_info
                 )
             ).strip()
@@ -391,7 +396,9 @@ class WebArenaChallenge(BaseChallenge):
         if request.config.getoption("--nc"):
             timeout = 100000
         elif cutoff := request.config.getoption("--cutoff"):
-            timeout = int(cutoff)
+            timeout = int(cutoff)  # type: ignore

+        assert isinstance(request.node, pytest.Item)

         n_steps = 0
         timed_out = None
@@ -400,7 +407,7 @@ class WebArenaChallenge(BaseChallenge):
         eval_results_per_step: list[list[tuple[_Eval, EvalResult]]] = []
         try:
             async for step in self.run_challenge(
-                config, timeout, mock=request.config.getoption("--mock")
+                config, timeout, mock=bool(request.config.getoption("--mock"))
             ):
                 if not step.output:
                     logger.warn(f"Step has no output: {step}")
@@ -415,7 +422,7 @@ class WebArenaChallenge(BaseChallenge):
                 )

                 step_eval_results = self.evaluate_step_result(
-                    step, mock=request.config.getoption("--mock")
+                    step, mock=bool(request.config.getoption("--mock"))
                 )
                 logger.debug(f"Intermediary results: {step_eval_results}")
                 eval_results_per_step.append(step_eval_results)
@@ -462,7 +469,7 @@ class WebArenaChallenge(BaseChallenge):


 def load_webarena_challenges(
-    skip_unavailable: bool = True
+    skip_unavailable: bool = True,
 ) -> Iterator[type[WebArenaChallenge]]:
     logger.info("Loading WebArena challenges...")

@@ -123,8 +123,10 @@ def check_regression(request: pytest.FixtureRequest) -> None:
     with contextlib.suppress(FileNotFoundError):
         rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)

+        assert isinstance(request.node, pytest.Function)
+        assert isinstance(request.node.parent, pytest.Class)
         test_name = request.node.parent.name
-        challenge_location = getattr(request.node.parent.cls, "CHALLENGE_LOCATION", "")
+        challenge_location = getattr(request.node.cls, "CHALLENGE_LOCATION", "")
        skip_string = f"Skipping {test_name} at {challenge_location}"

         # Check if the test name exists in the regression tests
@@ -148,7 +150,9 @@ def mock(request: pytest.FixtureRequest) -> bool:
     Returns:
         bool: Whether `--mock` is set for this session.
     """
-    return request.config.getoption("--mock")
+    mock = request.config.getoption("--mock")
+    assert isinstance(mock, bool)
+    return mock


 test_reports: dict[str, Test] = {}
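`Config.getoption` is annotated as returning `Any`, so the added assert is what narrows the value to `bool` for Pyright, and it doubles as a runtime sanity check; the `bool(...)` casts in the challenge classes above serve the same purpose where an assert would be awkward. A minimal sketch of the pattern (the flag name is illustrative):

```python
import pytest


def is_flag_set(config: pytest.Config, flag: str) -> bool:
    # Config.getoption() returns Any; an isinstance assert both narrows the
    # type for the checker and catches a misdeclared option at runtime.
    value = config.getoption(flag)
    assert isinstance(value, bool)
    return value
```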
@@ -221,7 +225,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):


 def pytest_collection_modifyitems(
-    items: list[pytest.Item], config: pytest.Config
+    items: list[pytest.Function], config: pytest.Config
 ) -> None:
     """
     Pytest hook that is called after initial test collection has been performed.
@@ -248,8 +252,9 @@ def pytest_collection_modifyitems(
     i = 0
     while i < len(items):
         item = items[i]
+        assert item.cls and issubclass(item.cls, BaseChallenge)
         challenge = item.cls
-        challenge_name = item.cls.__name__
+        challenge_name = challenge.info.name

         if not issubclass(challenge, BaseChallenge):
             item.warn(
@@ -18,9 +18,9 @@ def run_benchmark(
     maintain: bool = False,
     improve: bool = False,
     explore: bool = False,
-    tests: tuple[str] = tuple(),
-    categories: tuple[str] = tuple(),
-    skip_categories: tuple[str] = tuple(),
+    tests: tuple[str, ...] = tuple(),
+    categories: tuple[str, ...] = tuple(),
+    skip_categories: tuple[str, ...] = tuple(),
     attempts_per_challenge: int = 1,
     mock: bool = False,
     no_dep: bool = False,
@@ -53,9 +53,9 @@ class SingletonReportManager:
     @classmethod
     def clear_instance(cls):
         cls.instance = None
-        cls.INFO_MANAGER = None
-        cls.REGRESSION_MANAGER = None
-        cls.SUCCESS_RATE_TRACKER = None
+        del cls.INFO_MANAGER
+        del cls.REGRESSION_MANAGER
+        del cls.SUCCESS_RATE_TRACKER


 class BaseReportManager:
@@ -99,7 +99,8 @@ class BaseReportManager:
 class SessionReportManager(BaseReportManager):
     """Abstracts interaction with the regression tests file"""

-    tests: dict[str, Test] | Report
+    tests: dict[str, Test]
+    report: Report | None = None

     def __init__(self, report_file: Path, benchmark_start_time: datetime):
         super().__init__(report_file)
@@ -109,20 +110,21 @@ class SessionReportManager(BaseReportManager):

     def save(self) -> None:
         with self.report_file.open("w") as f:
-            if isinstance(self.tests, Report):
-                f.write(self.tests.json(indent=4))
+            if self.report:
+                f.write(self.report.json(indent=4))
             else:
                 json.dump({k: v.dict() for k, v in self.tests.items()}, f, indent=4)

     def load(self) -> None:
         super().load()
-        if "tests" in self.tests:  # type: ignore
-            self.tests = Report.parse_obj(self.tests)
+
+        if "tests" in self.tests:
+            self.report = Report.parse_obj(self.tests)
         else:
             self.tests = {n: Test.parse_obj(d) for n, d in self.tests.items()}

     def add_test_report(self, test_name: str, test_report: Test) -> None:
-        if isinstance(self.tests, Report):
+        if self.report:
             raise RuntimeError("Session report already finalized")

         if test_name.startswith("Test"):
@@ -134,10 +136,10 @@ class SessionReportManager(BaseReportManager):
     def finalize_session_report(self, config: AgentBenchmarkConfig) -> None:
         command = " ".join(sys.argv)

-        if isinstance(self.tests, Report):
+        if self.report:
             raise RuntimeError("Session report already finalized")

-        self.tests = Report(
+        self.report = Report(
             command=command.split(os.sep)[-1],
             benchmark_git_commit_sha="---",
             agent_git_commit_sha="---",
@@ -156,7 +158,7 @@ class SessionReportManager(BaseReportManager):
             config=config.dict(exclude={"reports_folder"}, exclude_none=True),
         )

-        agent_categories = get_highest_achieved_difficulty_per_category(self.tests)
+        agent_categories = get_highest_achieved_difficulty_per_category(self.report)
         if len(agent_categories) > 1:
             save_single_radar_chart(
                 agent_categories,
@@ -166,8 +168,8 @@ class SessionReportManager(BaseReportManager):
         self.save()

     def get_total_costs(self):
-        if isinstance(self.tests, Report):
-            tests = self.tests.tests
+        if self.report:
+            tests = self.report.tests
         else:
             tests = self.tests
|
@ -3,7 +3,7 @@ Model definitions used internally and for reports generated during command-line
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
from typing import Annotated, Any, Dict, List
|
||||
|
||||
from agent_protocol_client import Step
|
||||
from pydantic import BaseModel, Field, constr, validator
|
||||
@ -88,7 +88,7 @@ class Test(BaseModel):
|
||||
class ReportBase(BaseModel):
|
||||
command: str
|
||||
completion_time: str | None = None
|
||||
benchmark_start_time: constr(regex=datetime_format)
|
||||
benchmark_start_time: Annotated[str, constr(regex=datetime_format)]
|
||||
metrics: MetricsOverall
|
||||
config: Dict[str, str | dict[str, str]]
|
||||
agent_git_commit_sha: str | None = None
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Model definitions for use in the API"""
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, constr
|
||||
|
||||
@ -36,7 +37,7 @@ class RunDetails(BaseModel):
|
||||
run_id: str | None = None
|
||||
command: str
|
||||
completion_time: str | None = None
|
||||
benchmark_start_time: constr(regex=datetime_format)
|
||||
benchmark_start_time: Annotated[str, constr(regex=datetime_format)]
|
||||
|
||||
|
||||
class BenchmarkRun(BaseModel):
|
||||
|
@ -1,14 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskInput(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class TaskRequestBody(BaseModel):
|
||||
input: str = Field(
|
||||
...,
|
||||
@ -16,7 +12,7 @@ class TaskRequestBody(BaseModel):
|
||||
description="Input prompt for the task.",
|
||||
example="Write the words you receive to the file 'output.txt'.",
|
||||
)
|
||||
additional_input: Optional[TaskInput] = {}
|
||||
additional_input: Optional[dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TaskEvalRequestBody(TaskRequestBody):
|
||||
|
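Replacing `Optional[TaskInput] = {}` removes a default that never matched its own annotation; `Field(default_factory=dict)` also guarantees every instance gets a fresh dict. A minimal sketch of the behavior (the model name is illustrative):

```python
from typing import Any, Optional

from pydantic import BaseModel, Field


class Request(BaseModel):
    # default_factory builds a new dict per instance, and the default now
    # actually satisfies the `Optional[dict[str, Any]]` annotation.
    additional_input: Optional[dict[str, Any]] = Field(default_factory=dict)


a, b = Request(), Request()
assert a.additional_input == {} and a.additional_input is not b.additional_input
```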
@@ -32,7 +32,10 @@ def _add_ini_and_option(
     default: str | bool | int,
     **kwargs: Any,
 ) -> None:
-    """Add an option to both the ini file as well as the command line flags, with the latter overriding the former."""
+    """
+    Add an option to both the ini file and the command line flags.
+    Command line flags/options take precedence over the ini config.
+    """
     parser.addini(
         name,
         help + " This overrides the similarly named option from the config.",
@@ -44,7 +47,10 @@ def _add_ini_and_option(
 def _get_ini_or_option(
     config: Any, name: str, choices: Optional[list[str]]
 ) -> str | None:
-    """Get an option from either the ini file or the command line flags, the latter taking precedence."""
+    """
+    Get an option from either the ini file or the command line flags,
+    with the latter taking precedence.
+    """
     value = config.getini(name)
     if value is not None and choices is not None and value not in choices:
         raise ValueError(
@@ -73,7 +79,7 @@ def pytest_addoption(parser: Parser) -> None:
         default=False,
         help=(
             "List all non-nodeid dependency names + the tests they resolve to. "
-            "Will also list all nodeid dependency names when verbosity is high enough."
+            "Will also list all nodeid dependency names in verbose mode."
         ),
     )

@@ -83,7 +89,10 @@ def pytest_addoption(parser: Parser) -> None:
         "--list-processed-dependencies",
         action="store_true",
         default=False,
-        help="List all dependencies of all tests as a list of nodeids + the names that could not be resolved.",
+        help=(
+            "List all dependencies of all tests as a list of nodeids "
+            "+ the names that could not be resolved."
+        ),
     )

     # Add an ini option + flag to choose the action to take for failed dependencies
@@ -94,7 +103,8 @@ def pytest_addoption(parser: Parser) -> None:
         name="failed_dependency_action",
         help=(
             "The action to take when a test has dependencies that failed. "
-            'Use "run" to run the test anyway, "skip" to skip the test, and "fail" to fail the test.'
+            'Use "run" to run the test anyway, "skip" to skip the test, '
+            'and "fail" to fail the test.'
         ),
         default="skip",
         choices=DEPENDENCY_PROBLEM_ACTIONS.keys(),
@@ -107,8 +117,10 @@ def pytest_addoption(parser: Parser) -> None:
         group,
         name="missing_dependency_action",
         help=(
-            "The action to take when a test has dependencies that cannot be found within the current scope. "
-            'Use "run" to run the test anyway, "skip" to skip the test, and "fail" to fail the test.'
+            "The action to take when a test has dependencies that cannot be found "
+            "within the current scope. "
+            'Use "run" to run the test anyway, "skip" to skip the test, '
+            'and "fail" to fail the test.'
         ),
         default="warning",
         choices=DEPENDENCY_PROBLEM_ACTIONS.keys(),
@@ -139,7 +151,7 @@ def pytest_configure(config: Any) -> None:


 @pytest.hookimpl(trylast=True)
-def pytest_collection_modifyitems(config: Any, items: list[Item]) -> None:
+def pytest_collection_modifyitems(config: Any, items: list[pytest.Function]) -> None:
     manager = managers[-1]

     # Register the found tests on the manager
@@ -3,7 +3,7 @@
 # The name of the marker used
 MARKER_NAME = "depends"

-# The name of the keyword argument for the marker that contains custom name(s) for the tests
+# The name of the kwarg for 'depends' markers that contains custom name(s) for the tests
 MARKER_KWARG_ID = "name"

 # The name of the keyword argument for the marker that specifies the tests to depend on
@@ -57,8 +57,10 @@ def curved_edges(
     """
     ax = plt.gca()
     for u, v, data in G.edges(data=True):
-        src = np.array(pos[u])
-        dst = np.array(pos[v])
+        _src = pos[u]
+        _dst = pos[v]
+        src = np.array(_src)
+        dst = np.array(_dst)

         same_level = abs(src[1] - dst[1]) < 0.01

@@ -68,7 +70,7 @@ def curved_edges(
             arrow = patches.FancyArrowPatch(
                 posA=curve[0],  # type: ignore
                 posB=curve[-1],  # type: ignore
-                connectionstyle=f"arc3,rad=0.2",
+                connectionstyle="arc3,rad=0.2",
                 color="gray",
                 arrowstyle="-|>",
                 mutation_scale=15.0,
@@ -80,8 +82,8 @@ def curved_edges(
         else:
             ax.annotate(
                 "",
-                xy=dst,
-                xytext=src,
+                xy=_dst,
+                xytext=_src,
                 arrowprops=dict(
                     arrowstyle="-|>", color="gray", lw=1, shrinkA=10, shrinkB=10
                 ),
@@ -89,7 +91,8 @@ def curved_edges(


 def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, float]]:
-    """Compute positions as a tree layout centered on the root with alternating vertical shifts."""
+    """Compute positions as a tree layout centered on the root
+    with alternating vertical shifts."""
     bfs_tree = nx.bfs_tree(graph, source=root_node)
     levels = {
         node: depth
@@ -137,7 +140,7 @@ def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, flo
 def graph_spring_layout(
     dag: nx.DiGraph, labels: Dict[Any, str], tree: bool = True
 ) -> None:
-    num_nodes = len(dag.nodes())
+    num_nodes = len(list(dag.nodes()))
     # Setting up the figure and axis
     fig, ax = plt.subplots()
     ax.axis("off")  # Turn off the axis
@@ -288,7 +291,8 @@ def graph_interactive_network(

     # Optionally, save to a file
     # Sync with the flutter UI
-    # this literally only works in the AutoGPT repo, but this part of the code is not reached if BUILD_SKILL_TREE is false
+    # this literally only works in the AutoGPT repo, but this part of the code
+    # is not reached if BUILD_SKILL_TREE is false
     write_pretty_json(graph_data, flutter_app_path / "tree_structure.json")
     validate_skill_tree(graph_data, "")

@@ -332,11 +336,13 @@ def graph_interactive_network(

 def extract_subgraph_based_on_category(graph, category):
     """
-    Extracts a subgraph that includes all nodes and edges required to reach all nodes with a specified category.
+    Extracts a subgraph that includes all nodes and edges required to reach all nodes
+    with a specified category.

     :param graph: The original graph.
     :param category: The target category.
-    :return: Subgraph with nodes and edges required to reach the nodes with the given category.
+    :return: Subgraph with nodes and edges required to reach the nodes
+        with the given category.
     """

     subgraph = {"nodes": [], "edges": []}
@@ -424,7 +430,8 @@ def get_roots(graph):

 def validate_skill_tree(graph, skill_tree_name):
     """
-    Validate if a given graph represents a valid skill tree and raise appropriate exceptions if not.
+    Validate if a given graph represents a valid skill tree
+    and raise appropriate exceptions if not.

     :param graph: A dictionary representing the graph with 'nodes' and 'edges'.
     :raises: ValueError with a description of the invalidity.
@@ -434,7 +441,8 @@ def validate_skill_tree(graph, skill_tree_name):
     if cycle_path:
         cycle_str = " -> ".join(cycle_path)
         raise ValueError(
-            f"{skill_tree_name} skill tree is circular! Circular path detected: {cycle_str}."
+            f"{skill_tree_name} skill tree is circular! "
+            f"Detected circular path: {cycle_str}."
         )

     # Check for multiple roots
@@ -1,18 +1,19 @@
 """
 A module to manage dependencies between pytest tests.

-This module provides the methods implementing the main logic. These are used in the pytest hooks that are in
-__init__.py.
+This module provides the methods implementing the main logic.
+These are used in the pytest hooks that are in __init__.py.
 """

 import collections
 import json
 import os
 from typing import Any, Generator

 import colorama
 import networkx
-from _pytest.nodes import Item
+from pytest import Function, Item

+from agbenchmark.challenges.base import BaseChallenge
 from .constants import MARKER_KWARG_DEPENDENCIES, MARKER_NAME
 from .graphs import graph_interactive_network
@@ -38,7 +39,8 @@ class TestResult(object):
         )
         if result.when in self.results:
             raise AttributeError(
-                f"Received multiple results for step {result.when} of test {self.nodeid}"
+                f"Received multiple results for step {result.when} "
+                f"of test {self.nodeid}"
             )
         self.results[result.when] = result.outcome

@@ -66,7 +68,7 @@ class TestDependencies(object):
             for dep in marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])
         ]
         for dependency in dependencies:
-            # If the name is not known, try to make it absolute (ie file::[class::]method)
+            # If the name is not known, try to make it absolute (file::[class::]method)
             if dependency not in manager.name_to_nodeids:
                 absolute_dependency = get_absolute_nodeid(dependency, self.nodeid)
                 if absolute_dependency in manager.name_to_nodeids:
@@ -86,20 +88,20 @@ class DependencyManager(object):
     def __init__(self) -> None:
         """Create a new DependencyManager."""
         self.options: dict[str, Any] = {}
-        self._items: list[Item] | None = None
+        self._items: list[Function] | None = None
         self._name_to_nodeids: Any = None
         self._nodeid_to_item: Any = None
         self._results: Any = None

     @property
-    def items(self) -> list[Item]:
+    def items(self) -> list[Function]:
         """The collected tests that are managed by this instance."""
         if self._items is None:
             raise AttributeError("The items attribute has not been set yet")
         return self._items

     @items.setter
-    def items(self, items: list[Item]) -> None:
+    def items(self, items: list[Function]) -> None:
         if self._items is not None:
             raise AttributeError("The items attribute has already been set")
         self._items = items
@@ -125,7 +127,8 @@ class DependencyManager(object):
         for item in items:
             nodeid = clean_nodeid(item.nodeid)
             # Process the dependencies of this test
-            # This uses the mappings created in the previous loop, and can thus not be merged into that loop
+            # This uses the mappings created in the previous loop,
+            # and can thus not be merged into that loop
             self._dependencies[nodeid] = TestDependencies(item, self)

     @property
@@ -135,7 +138,7 @@ class DependencyManager(object):
         return self._name_to_nodeids

     @property
-    def nodeid_to_item(self) -> dict[str, Item]:
+    def nodeid_to_item(self) -> dict[str, Function]:
         """A mapping from node ids to test items."""
         assert self.items is not None
         return self._nodeid_to_item
@@ -194,7 +197,9 @@ class DependencyManager(object):

     @property
     def sorted_items(self) -> Generator:
-        """Get a sorted list of tests where all tests are sorted after their dependencies."""
+        """
+        Get a sorted list of tests where all tests are sorted after their dependencies.
+        """
         # Build a directed graph for sorting
         build_skill_tree = os.getenv("BUILD_SKILL_TREE")
         BUILD_SKILL_TREE = (
@@ -202,8 +207,8 @@ class DependencyManager(object):
         )
         dag = networkx.DiGraph()

-        # Insert all items as nodes, to prevent items that have no dependencies and are not dependencies themselves from
-        # being lost
+        # Insert all items as nodes, to prevent items that have no dependencies
+        # and are not dependencies themselves from being lost
         dag.add_nodes_from(self.items)

         # Insert edges for all the dependencies
@@ -214,11 +219,8 @@ class DependencyManager(object):

         labels = {}
         for item in self.items:
-            try:
-                with open(item.cls.CHALLENGE_LOCATION) as f:
-                    data = json.load(f)
-            except:
-                data = {}
+            assert item.cls and issubclass(item.cls, BaseChallenge)
+            data = item.cls.info.dict()

             node_name = get_name(item)
             data["name"] = node_name
@@ -38,7 +38,8 @@ def strip_nodeid_parameters(nodeid: str) -> str:

 def get_absolute_nodeid(nodeid: str, scope: str) -> str:
     """
-    Transform a possibly relative node id to an absolute one using the scope in which it is used.
+    Transform a possibly relative node id to an absolute one
+    using the scope in which it is used.

     >>> scope = 'test_file.py::TestClass::test'
     >>> get_absolute_nodeid('test2', scope)
@@ -49,7 +50,7 @@ def get_absolute_nodeid(nodeid: str, scope: str) -> str:
     'test_file2.py::TestClass2::test2'
     """
     parts = nodeid.split("::")
-    # Completely relative (test_name), so add the full current scope (either file::class or file)
+    # Completely relative (test_name): add the full current scope (file::class or file)
     if len(parts) == 1:
         base_nodeid = scope.rsplit("::", 1)[0]
         nodeid = f"{base_nodeid}::{nodeid}"

@@ -15,7 +15,8 @@ def get_data_from_helicone(challenge: str) -> Optional[float]:
     # Define the endpoint of your GraphQL server
     url = "https://www.helicone.ai/api/graphql"

-    # Set the headers, usually you'd need to set the content type and possibly an authorization token
+    # Set the headers, usually you'd need to set the content type
+    # and possibly an authorization token
     headers = {"authorization": f"Bearer {os.environ.get('HELICONE_API_KEY')}"}

     # Define the query, variables, and operation name
@@ -1,7 +1,18 @@
 SCORING_MAP = {
-    "percentage": "assign a float score that will represent a percentage out of 100. Use decimal points to be even more accurate. 0 represents the worst possible generation, while 100 represents the ideal generation",
-    "scale": "assign an integer score from a scale of 1-10. 1 represents a really bad generation, while 10 represents an ideal generation",
-    "binary": "assign a binary score of either 0 or 1. 0 represents a failure, while 1 represents a success",
+    "percentage": (
+        "assign a float score that will represent a percentage out of 100. "
+        "Use decimal points to be even more accurate. "
+        "0 represents the worst possible generation, "
+        "while 100 represents the ideal generation"
+    ),
+    "scale": (
+        "assign an integer score from a scale of 1-10. "
+        "1 represents a really bad generation, while 10 represents an ideal generation"
+    ),
+    "binary": (
+        "assign a binary score of either 0 or 1. "
+        "0 represents a failure, while 1 represents a success"
+    ),
 }

@@ -17,7 +28,7 @@ Here is the ideal response you're comparing to based on the task:
 Here is the current machine generated response to the task that you need to evaluate:
 {response}

-"""
+"""  # noqa: E501

 RUBRIC_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks.
 In order to score the generated texts you will {scoring}. Make sure to factor in rubric into your thinking, deliberation, and final result regarding scoring. Return nothing but a float score.
@@ -31,7 +42,7 @@ Use the below rubric to guide your thinking about scoring:
 Here is the current machine generated response to the task that you need to evaluate:
 {response}

-"""
+"""  # noqa: E501

 QUESTION_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks.
 In order to score the generated texts you will {scoring}. Make sure to think about whether the generated response answers the question well in order to score accurately. Return nothing but a float score.
@@ -45,12 +56,12 @@ Here is a question that checks if the task was completed correctly:
 Here is the current machine generated response to the task that you need to evaluate:
 {response}

-"""
+"""  # noqa: E501

 FEW_SHOT_EXAMPLES = """Here are some examples of how to score a machine generated response based on the above:
 {examples}

-"""
+"""  # noqa: E501

 CUSTOM_PROMPT = """{custom}
 {scoring}
@@ -202,11 +202,15 @@ def sorted_by_enum_index(
     sortable: Iterable[T],
     enum: type[Enum],
     *,
-    key: Callable[[T], Enum | None] = lambda x: x,  # type: ignore
+    key: Optional[Callable[[T], Enum | None]] = None,
     reverse: bool = False,
 ) -> list[T]:
     return sorted(
         sortable,
-        key=lambda x: enum._member_names_.index(e.name) if (e := key(x)) else 420e3,
+        key=lambda x: (
+            enum._member_names_.index(e.name)  # type: ignore
+            if (e := key(x) if key else x)
+            else 420e3
+        ),
         reverse=reverse,
     )
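With `key` now defaulting to `None`, the helper sorts enum members (or `None`s) directly without the mistyped identity lambda. An equivalent inline form of the fixed key logic, runnable on its own:

```python
from enum import Enum
from typing import Optional


class Difficulty(Enum):
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"


# Enum members sort by their position in the Enum definition; None values
# fall to the end via the 420e3 sentinel, mirroring sorted_by_enum_index.
items: list[Optional[Difficulty]] = [Difficulty.HARD, None, Difficulty.EASY]
ordered = sorted(
    items,
    key=lambda x: Difficulty._member_names_.index(x.name) if x else 420e3,
)
assert ordered == [Difficulty.EASY, Difficulty.HARD, None]
```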
@@ -1,13 +0,0 @@
-[mypy]
-namespace_packages = True
-follow_imports = skip
-check_untyped_defs = True
-disallow_untyped_defs = True
-exclude = ^(agbenchmark/challenges/|agent/|venv|venv-dev)
-ignore_missing_imports = True
-
-[mypy-agbenchmark.utils.data_types.*]
-ignore_errors = True
-
-[mypy-numpy.*]
-ignore_errors = True
213
benchmark/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "agent-protocol-client"
|
||||
@ -197,63 +197,49 @@ tests = ["attrs[tests-no-zope]", "zope-interface"]
|
||||
tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
|
||||
tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
|
||||
|
||||
[[package]]
|
||||
name = "autoflake"
|
||||
version = "1.7.8"
|
||||
description = "Removes unused imports and unused variables"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "autoflake-1.7.8-py3-none-any.whl", hash = "sha256:46373ef69b6714f5064c923bb28bd797c4f8a9497f557d87fc36665c6d956b39"},
|
||||
{file = "autoflake-1.7.8.tar.gz", hash = "sha256:e7e46372dee46fa1c97acf310d99d922b63d369718a270809d7c278d34a194cf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pyflakes = ">=1.1.0,<3"
|
||||
tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "22.3.0"
|
||||
version = "23.12.1"
|
||||
description = "The uncompromising code formatter."
|
||||
optional = false
|
||||
python-versions = ">=3.6.2"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "black-22.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2497f9c2386572e28921fa8bec7be3e51de6801f7459dffd6e62492531c47e09"},
|
||||
{file = "black-22.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5795a0375eb87bfe902e80e0c8cfaedf8af4d49694d69161e5bd3206c18618bb"},
|
||||
{file = "black-22.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3556168e2e5c49629f7b0f377070240bd5511e45e25a4497bb0073d9dda776a"},
|
||||
{file = "black-22.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67c8301ec94e3bcc8906740fe071391bce40a862b7be0b86fb5382beefecd968"},
|
||||
{file = "black-22.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:fd57160949179ec517d32ac2ac898b5f20d68ed1a9c977346efbac9c2f1e779d"},
{file = "black-22.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cc1e1de68c8e5444e8f94c3670bb48a2beef0e91dddfd4fcc29595ebd90bb9ce"},
{file = "black-22.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2fc92002d44746d3e7db7cf9313cf4452f43e9ea77a2c939defce3b10b5c82"},
{file = "black-22.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:a6342964b43a99dbc72f72812bf88cad8f0217ae9acb47c0d4f141a6416d2d7b"},
{file = "black-22.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:328efc0cc70ccb23429d6be184a15ce613f676bdfc85e5fe8ea2a9354b4e9015"},
{file = "black-22.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06f9d8846f2340dfac80ceb20200ea5d1b3f181dd0556b47af4e8e0b24fa0a6b"},
{file = "black-22.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4efa5fad66b903b4a5f96d91461d90b9507a812b3c5de657d544215bb7877a"},
{file = "black-22.3.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8477ec6bbfe0312c128e74644ac8a02ca06bcdb8982d4ee06f209be28cdf163"},
{file = "black-22.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:637a4014c63fbf42a692d22b55d8ad6968a946b4a6ebc385c5505d9625b6a464"},
{file = "black-22.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:863714200ada56cbc366dc9ae5291ceb936573155f8bf8e9de92aef51f3ad0f0"},
{file = "black-22.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10dbe6e6d2988049b4655b2b739f98785a884d4d6b85bc35133a8fb9a2233176"},
{file = "black-22.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:cee3e11161dde1b2a33a904b850b0899e0424cc331b7295f2a9698e79f9a69a0"},
{file = "black-22.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5891ef8abc06576985de8fa88e95ab70641de6c1fca97e2a15820a9b69e51b20"},
{file = "black-22.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:30d78ba6bf080eeaf0b7b875d924b15cd46fec5fd044ddfbad38c8ea9171043a"},
{file = "black-22.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ee8f1f7228cce7dffc2b464f07ce769f478968bfb3dd1254a4c2eeed84928aad"},
{file = "black-22.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ee227b696ca60dd1c507be80a6bc849a5a6ab57ac7352aad1ffec9e8b805f21"},
{file = "black-22.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:9b542ced1ec0ceeff5b37d69838106a6348e60db7b8fdd245294dc1d26136265"},
{file = "black-22.3.0-py3-none-any.whl", hash = "sha256:bc58025940a896d7e5356952228b68f793cf5fcb342be703c3a2669a1488cb72"},
{file = "black-22.3.0.tar.gz", hash = "sha256:35020b8886c022ced9282b51b5a875b6d1ab0c387b31a065b84db7c33085ca79"},
{file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"},
{file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"},
{file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"},
{file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"},
{file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"},
{file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"},
{file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"},
{file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"},
{file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"},
{file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"},
{file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"},
{file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"},
{file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"},
{file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"},
{file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"},
{file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"},
{file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"},
{file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"},
{file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"},
{file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"},
{file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"},
{file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"},
]

[package.dependencies]
click = ">=8.0.0"
mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0"
platformdirs = ">=2"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}

[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.7.4)"]
d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]

@ -558,6 +544,73 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.6.1)", "types-Pill
test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]

[[package]]
name = "coverage"
version = "7.5.1"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
{file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
{file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
{file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
{file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
{file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
{file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
{file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
{file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
{file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
{file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
{file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
{file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
{file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
{file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
{file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
{file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
{file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
{file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
{file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
{file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
{file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
]

[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}

[package.extras]
toml = ["tomli"]

[[package]]
name = "cycler"
version = "0.12.1"
@ -671,19 +724,19 @@ typing = ["typing-extensions (>=4.8)"]

[[package]]
name = "flake8"
version = "3.9.2"
version = "7.0.0"
description = "the modular source code checker: pep8 pyflakes and co"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
python-versions = ">=3.8.1"
files = [
{file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"},
{file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"},
{file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"},
{file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"},
]

[package.dependencies]
mccabe = ">=0.6.0,<0.7.0"
pycodestyle = ">=2.7.0,<2.8.0"
pyflakes = ">=2.3.0,<2.4.0"
mccabe = ">=0.7.0,<0.8.0"
pycodestyle = ">=2.11.0,<2.12.0"
pyflakes = ">=3.2.0,<3.3.0"

[[package]]
name = "fonttools"
@ -1376,13 +1429,13 @@ traitlets = "*"

[[package]]
name = "mccabe"
version = "0.6.1"
version = "0.7.0"
description = "McCabe checker, plugin for flake8"
optional = false
python-versions = "*"
python-versions = ">=3.6"
files = [
{file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"},
{file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
{file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]

[[package]]
@ -1973,13 +2026,13 @@ pyasn1 = ">=0.4.6,<0.6.0"

[[package]]
name = "pycodestyle"
version = "2.7.0"
version = "2.11.1"
description = "Python style guide checker"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
python-versions = ">=3.8"
files = [
{file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"},
{file = "pycodestyle-2.7.0.tar.gz", hash = "sha256:c389c1d06bf7904078ca03399a4816f974a1d590090fecea0c63ec26ebaf1cef"},
{file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"},
{file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"},
]

[[package]]
@ -2047,13 +2100,13 @@ email = ["email-validator (>=1.0.3)"]

[[package]]
name = "pyflakes"
version = "2.3.1"
version = "3.2.0"
description = "passive checker of Python programs"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
python-versions = ">=3.8"
files = [
{file = "pyflakes-2.3.1-py2.py3-none-any.whl", hash = "sha256:7893783d01b8a89811dd72d7dfd4d84ff098e5eed95cfa8905b22bbffe52efc3"},
{file = "pyflakes-2.3.1.tar.gz", hash = "sha256:f5bc8ecabc05bb9d291eb5203d6810b49040f6ff446a756326104746cc00c1db"},
{file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"},
{file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"},
]

[[package]]
@ -2085,6 +2138,24 @@ files = [
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]

[[package]]
name = "pyright"
version = "1.1.364"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyright-1.1.364-py3-none-any.whl", hash = "sha256:865f1e02873c5dc7427c95acf53659a118574010e6fb364e27e47ec5c46a9f26"},
{file = "pyright-1.1.364.tar.gz", hash = "sha256:612a2106a4078ec57efc22b5620729e9bdf4a3c17caba013b534bd33f7d08e5a"},
]

[package.dependencies]
nodeenv = ">=1.6.0"

[package.extras]
all = ["twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]

[[package]]
name = "pysocks"
version = "1.7.1"
@ -2137,6 +2208,24 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]

[[package]]
name = "pytest-cov"
version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]

[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"

[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]

[[package]]
name = "python-dateutil"
version = "2.8.2"
@ -2774,4 +2863,4 @@ multidict = ">=4.0"

[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9"
content-hash = "4a980e6d8f54a2f7f6a3c55d4f40ac3a4b27b5ac6573dd2a39e11213a4b126dd"

@ -37,59 +37,49 @@ click-default-group = "^1.2.4"
tabulate = "^0.9.0"

[tool.poetry.group.dev.dependencies]
flake8 = "^3.9.2"
isort = "^5.9.3"
black = "22.3"
autoflake = "^1.4"
black = "^23.12.1"
flake8 = "^7.0.0"
isort = "^5.13.1"
pyright = "^1.1.364"
pandas = "^2.0.3"
gspread = "^5.10.0"
oauth2client = "^4.1.3"
pre-commit = "^3.3.3"
pytest-cov = "^5.0.0"

[tool.poetry.scripts]
agbenchmark = "agbenchmark.__main__:cli"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q"
testpaths = [
    "tests", "agbenchmark",
]
asyncio_mode = "auto"
markers = [
    "interface",
    "code",
    "memory",
    "iterate",
    "adaptability",
    "safety",
    "content_gen",
    "product_advisor"
]
filterwarnings = [
    "ignore::pytest.PytestAssertRewriteWarning",
    "ignore::matplotlib.MatplotlibDeprecationWarning"
]


[tool.black]
line-length = 88
target-version = ['py310']
include = '\.pyi?$'
packages = ["autogpt"]
extend-exclude = '(/dist|/.venv|/venv|/build|/agent|agbenchmark/challenges)/'


[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
skip_glob = [".tox", "__pycache__", "*.pyc", "venv*/*", "reports", "venv", "env", "node_modules", ".env", ".venv", "dist", "agent/*", "agbenchmark/challenges/*"]
skip_glob = ["reports"]

[tool.poetry.scripts]
agbenchmark = "agbenchmark.__main__:cli"

[tool.pyright]
pythonVersion = "3.10"
exclude = [
    "notebooks/**",
    "reports/**",
    "**/node_modules",
    "**/__pycache__",
    "**/.*",
]
ignore = [
    "../forge/**"
]


[tool.pytest.ini_options]
testpaths = ["tests"]

@ -17,7 +17,7 @@ def print_markdown_report(report_json_file: str):
    report = Report.parse_file(report_json_file)

    # Header and metadata
    click.echo(f"# Benchmark Report")
    click.echo("# Benchmark Report")
    click.echo(f"- ⌛ **Run time:** `{report.metrics.run_time}`")
    click.echo(
        f" - **Started at:** `{report.benchmark_start_time[:16].replace('T', '` `')}`"

@ -1,11 +1,16 @@
import datetime
import time

import pytest
import requests

URL_BENCHMARK = "http://localhost:8080/ap/v1"
URL_AGENT = "http://localhost:8000/ap/v1"

import datetime
import time
try:
    response = requests.get(f"{URL_AGENT}/agent/tasks")
except requests.exceptions.ConnectionError:
    pytest.skip("No agent available to test against", allow_module_level=True)


@pytest.mark.parametrize(
@ -20,7 +25,8 @@ import time
        ),
        (
            "f219f3d3-a41b-45a9-a3d0-389832086ee8",
            "Read the file called file_to_read.txt and write its content to a file called output.txt",
            "Read the file called file_to_read.txt "
            "and write its content to a file called output.txt",
            1,
            "ReadFile",
            False,
@ -28,7 +34,11 @@ import time
    ],
)
def test_entire_workflow(
    eval_id, input_text, expected_artifact_length, test_name, should_be_successful
    eval_id: str,
    input_text: str,
    expected_artifact_length: int,
    test_name: str,
    should_be_successful: bool,
):
    task_request = {"eval_id": eval_id, "input": input_text}
    response = requests.get(f"{URL_AGENT}/agent/tasks")
@ -64,7 +74,7 @@ def test_entire_workflow(
    )
    assert step_response.status_code == 200
    step_response = step_response.json()
    assert step_response["is_last"] == True  # Assuming is_last is always True
    assert step_response["is_last"] is True  # Assuming is_last is always True

    eval_response = requests.post(
        URL_BENCHMARK + "/agent/tasks/" + task_response_benchmark_id + "/evaluations",
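
The module-level connection check near the top of this file is what lets the whole test module opt out when no agent is listening: `pytest.skip(..., allow_module_level=True)` aborts collection of the module instead of failing every test at runtime. A minimal, self-contained sketch of the same guard (the URL is a placeholder, not the benchmark's real endpoint):

```python
import pytest
import requests

SERVICE_URL = "http://localhost:8000/ap/v1"  # placeholder endpoint

try:
    # Probe the service once at import time
    requests.get(f"{SERVICE_URL}/agent/tasks", timeout=5)
except requests.exceptions.ConnectionError:
    # Skips collection of every test in this module
    pytest.skip("service not reachable", allow_module_level=True)
```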

cli.py
|
    script_dir = os.path.dirname(os.path.realpath(__file__))
    agent_dir = os.path.join(
        script_dir,
        f"agents/{agent_name}" if agent_name not in ["autogpt", "forge"] else agent_name,
        f"agents/{agent_name}"
        if agent_name not in ["autogpt", "forge"]
        else agent_name,
    )
    run_command = os.path.join(agent_dir, "run")
    run_bench_command = os.path.join(agent_dir, "run_benchmark")
@ -247,7 +249,9 @@ def start(agent_name, subprocess_args):
    script_dir = os.path.dirname(os.path.realpath(__file__))
    agent_dir = os.path.join(
        script_dir,
        f"agents/{agent_name}" if agent_name not in ["autogpt", "forge"] else agent_name,
        f"agents/{agent_name}"
        if agent_name not in ["autogpt", "forge"]
        else agent_name,
    )
    benchmark_script = os.path.join(agent_dir, "run_benchmark")
    if os.path.exists(agent_dir) and os.path.isfile(benchmark_script):

@ -202,7 +202,7 @@ class MyAgent(Agent):
    def __init__(
        self,
        settings: AgentSettings,
        llm_provider: ChatModelProvider,
        llm_provider: MultiProvider,
        file_storage: FileStorage,
        legacy_config: Config,
    ):
@ -219,7 +219,7 @@ class MyAgent(Agent):
    def __init__(
        self,
        settings: AgentSettings,
        llm_provider: ChatModelProvider,
        llm_provider: MultiProvider,
        file_storage: FileStorage,
        legacy_config: Config,
    ):
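
Both constructor signatures above change in the same way: the `llm_provider` parameter is now typed as `MultiProvider` instead of the removed `ChatModelProvider` alias. A hedged sketch of what calling code might look like, assuming `MultiProvider` is importable from `forge.llm.providers` and can be constructed with defaults (the surrounding `settings`, `file_storage`, and `legacy_config` objects are placeholders from the application's own setup):

```python
from forge.llm.providers import MultiProvider

# Illustrative wiring only, not the project's real bootstrap code.
agent = MyAgent(
    settings=settings,
    llm_provider=MultiProvider(),
    file_storage=file_storage,
    legacy_config=legacy_config,
)
```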
@ -1,15 +1,11 @@
[flake8]
max-line-length = 88
select = "E303, W293, W292, E305, E231, E302"
# Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude =
    .tox,
    __pycache__,
    .git,
    __pycache__/,
    *.pyc,
    .env
    venv*/*,
    .venv/*,
    reports/*,
    dist/*,
    agent/*,
    code,
    agbenchmark/challenges/*
    .pytest_cache/,
    venv*/,
    .venv/,

forge/.gitignore
@ -160,7 +160,8 @@ CURRENT_BULLETIN.md

agbenchmark_config/workspace
agbenchmark_config/reports
*.sqlite
*.sqlite*
*.db
.agbench
.agbenchmark
.benchmarks
@ -168,7 +169,7 @@ agbenchmark_config/reports
.pytest_cache
.vscode
ig_*
agent.db
agbenchmark_config/updates.json
agbenchmark_config/challenges_already_beaten.json
agbenchmark_config/temp_folder/*
test_workspace/

@ -1,43 +0,0 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-added-large-files
        args: ['--maxkb=500']
      - id: check-byte-order-marker
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: check-symlinks
      - id: debug-statements

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        language_version: python3.11

  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3.11

  # - repo: https://github.com/pre-commit/mirrors-mypy
  #   rev: 'v1.3.0'
  #   hooks:
  #     - id: mypy

  - repo: local
    hooks:
      - id: autoflake
        name: autoflake
        entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
        language: python
        types: [ python ]
      # Mono repo has bronken this TODO: fix
      # - id: pytest-check
      #   name: pytest-check
      #   entry: pytest
      #   language: system
      #   pass_filenames: false
      #   always_run: true
@ -9,27 +9,24 @@ from forge.logging.config import configure_logging
logger = logging.getLogger(__name__)

logo = """\n\n
d8888 888 .d8888b. 8888888b. 88888888888
d88888 888 d88P Y88b 888 Y88b 888
d88P888 888 888 888 888 888 888
d88P 888 888 888 888888 .d88b. 888 888 d88P 888
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
d88P 888 888 888 888 888 888 888 888 888 888
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888



8888888888
888
888
8888888 .d88b. 888d888 .d88b. .d88b.
888 d88""88b 888P" d88P"88b d8P Y8b
888 888 888 888 888 888 88888888
888 Y88..88P 888 Y88b 888 Y8b.
888 "Y88P" 888 "Y88888 "Y8888
888
Y8b d88P
d8888 888 .d8888b. 8888888b. 88888888888
d88P888 888 888 888 888 888 888
d88P 888 888 888 888888 .d88b. 888 888 d88P 888
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
d88P 888 888 888 888 888 888 888 888 888 888
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888


8888888888
888
888 .d88b. 888d888 .d88b. .d88b.
888888 d88""88b 888P" d88P"88b d8P Y8b
888 888 888 888 888 888 88888888
888 Y88..88P 888 Y88b 888 Y8b.
888 "Y88P" 888 "Y88888 "Y8888
888
Y8b d88P
"Y88P" v0.1.0
\n"""

@ -1,15 +1,7 @@
from .base import AgentMeta, BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from .components import (
    AgentComponent,
    ComponentEndpointError,
    ComponentSystemError,
    EndpointPipelineError,
)
from .protocols import (
    AfterExecute,
    AfterParse,
    CommandProvider,
    DirectiveProvider,
    ExecutionFailure,
    MessageProvider,
)
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings

__all__ = [
    "BaseAgent",
    "BaseAgentConfiguration",
    "BaseAgentSettings",
]

@ -24,7 +24,6 @@ from forge.agent_protocol.models.task import (
    TaskStepsListResponse,
)
from forge.file_storage.base import FileStorage
from forge.utils.exceptions import NotFoundError

logger = logging.getLogger(__name__)

@ -79,7 +78,8 @@ class Agent:

        else:
            logger.warning(
                f"Frontend not found. {frontend_path} does not exist. The frontend will not be served"
                f"Frontend not found. {frontend_path} does not exist. "
                "The frontend will not be served."
            )
        app.add_middleware(AgentMiddleware, agent=self)

@ -94,34 +94,25 @@ class Agent:
        """
        Create a task for the agent.
        """
        try:
            task = await self.db.create_task(
                input=task_request.input,
                additional_input=task_request.additional_input,
            )
            return task
        except Exception as e:
            raise
        task = await self.db.create_task(
            input=task_request.input,
            additional_input=task_request.additional_input,
        )
        return task

    async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
        """
        List all tasks that the agent has created.
        """
        try:
            tasks, pagination = await self.db.list_tasks(page, pageSize)
            response = TaskListResponse(tasks=tasks, pagination=pagination)
            return response
        except Exception as e:
            raise
        tasks, pagination = await self.db.list_tasks(page, pageSize)
        response = TaskListResponse(tasks=tasks, pagination=pagination)
        return response

    async def get_task(self, task_id: str) -> Task:
        """
        Get a task by ID.
        """
        try:
            task = await self.db.get_task(task_id)
        except Exception as e:
            raise
        task = await self.db.get_task(task_id)
        return task

    async def list_steps(
@ -130,12 +121,9 @@ class Agent:
        """
        List the IDs of all steps that the task has created.
        """
        try:
            steps, pagination = await self.db.list_steps(task_id, page, pageSize)
            response = TaskStepsListResponse(steps=steps, pagination=pagination)
            return response
        except Exception as e:
            raise
        steps, pagination = await self.db.list_steps(task_id, page, pageSize)
        response = TaskStepsListResponse(steps=steps, pagination=pagination)
        return response

    async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
        """
@ -147,11 +135,8 @@ class Agent:
        """
        Get a step by ID.
        """
        try:
            step = await self.db.get_step(task_id, step_id)
            return step
        except Exception as e:
            raise
        step = await self.db.get_step(task_id, step_id)
        return step

    async def list_artifacts(
        self, task_id: str, page: int = 1, pageSize: int = 10
@ -159,62 +144,45 @@ class Agent:
        """
        List the artifacts that the task has created.
        """
        try:
            artifacts, pagination = await self.db.list_artifacts(
                task_id, page, pageSize
            )
            return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)

        except Exception as e:
            raise
        artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
        return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)

    async def create_artifact(
        self, task_id: str, file: UploadFile, relative_path: str
        self, task_id: str, file: UploadFile, relative_path: str = ""
    ) -> Artifact:
        """
        Create an artifact for the task.
        """
        data = None
        file_name = file.filename or str(uuid4())
        try:
            data = b""
            while contents := file.file.read(1024 * 1024):
                data += contents
            # Check if relative path ends with filename
            if relative_path.endswith(file_name):
                file_path = relative_path
            else:
                file_path = os.path.join(relative_path, file_name)
        data = b""
        while contents := file.file.read(1024 * 1024):
            data += contents
        # Check if relative path ends with filename
        if relative_path.endswith(file_name):
            file_path = relative_path
        else:
            file_path = os.path.join(relative_path, file_name)

            await self.workspace.write_file(file_path, data)
        await self.workspace.write_file(file_path, data)

            artifact = await self.db.create_artifact(
                task_id=task_id,
                file_name=file_name,
                relative_path=relative_path,
                agent_created=False,
            )
        except Exception as e:
            raise
        artifact = await self.db.create_artifact(
            task_id=task_id,
            file_name=file_name,
            relative_path=relative_path,
            agent_created=False,
        )
        return artifact

    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
    async def get_artifact(self, task_id: str, artifact_id: str) -> StreamingResponse:
        """
        Get an artifact by ID.
        """
        try:
            artifact = await self.db.get_artifact(artifact_id)
            if artifact.file_name not in artifact.relative_path:
                file_path = os.path.join(artifact.relative_path, artifact.file_name)
            else:
                file_path = artifact.relative_path
            retrieved_artifact = self.workspace.read_file(file_path)
        except NotFoundError as e:
            raise
        except FileNotFoundError as e:
            raise
        except Exception as e:
            raise
        artifact = await self.db.get_artifact(artifact_id)
        if artifact.file_name not in artifact.relative_path:
            file_path = os.path.join(artifact.relative_path, artifact.file_name)
        else:
            file_path = artifact.relative_path
        retrieved_artifact = self.workspace.read_file(file_path, binary=True)

        return StreamingResponse(
            BytesIO(retrieved_artifact),
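
The reworked `get_artifact` now reads the file in binary mode and hands the bytes to a `StreamingResponse` wrapped around a `BytesIO` buffer, instead of returning the `Artifact` model. The response pattern in isolation, as a sketch (the `media_type` and header handling here are assumptions, not the exact code from this diff):

```python
from io import BytesIO

from fastapi.responses import StreamingResponse


def stream_bytes(data: bytes, filename: str) -> StreamingResponse:
    # Wrap in-memory bytes in a file-like object and stream them to the client.
    return StreamingResponse(
        BytesIO(data),
        media_type="application/octet-stream",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
```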
@ -1,6 +1,7 @@
from pathlib import Path

import pytest
from fastapi import UploadFile

from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.models.task import (
@ -16,16 +17,23 @@ from .agent import Agent


@pytest.fixture
def agent():
def agent(test_workspace: Path):
    db = AgentDB("sqlite:///test.db")
    config = FileStorageConfiguration(root=Path("./test_workspace"))
    config = FileStorageConfiguration(root=test_workspace)
    workspace = LocalFileStorage(config)
    return Agent(db, workspace)


@pytest.mark.skip
@pytest.fixture
def file_upload():
    this_file = Path(__file__)
    file_handle = this_file.open("rb")
    yield UploadFile(file_handle, filename=this_file.name)
    file_handle.close()


@pytest.mark.asyncio
async def test_create_task(agent):
async def test_create_task(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
@ -33,20 +41,18 @@ async def test_create_task(agent):
    assert task.input == "test_input"


@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_tasks(agent):
async def test_list_tasks(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task = await agent.create_task(task_request)
    await agent.create_task(task_request)
    tasks = await agent.list_tasks()
    assert isinstance(tasks, TaskListResponse)


@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_task(agent):
async def test_get_task(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
@ -55,9 +61,9 @@ async def test_get_task(agent):
    assert retrieved_task.task_id == task.task_id


@pytest.mark.skip
@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_create_and_execute_step(agent):
async def test_execute_step(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
@ -65,14 +71,14 @@ async def test_create_and_execute_step(agent):
    step_request = StepRequestBody(
        input="step_input", additional_input={"input": "additional_test_input"}
    )
    step = await agent.create_and_execute_step(task.task_id, step_request)
    step = await agent.execute_step(task.task_id, step_request)
    assert step.input == "step_input"
    assert step.additional_input == {"input": "additional_test_input"}


@pytest.mark.skip
@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_get_step(agent):
async def test_get_step(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
@ -80,38 +86,52 @@ async def test_get_step(agent):
    step_request = StepRequestBody(
        input="step_input", additional_input={"input": "additional_test_input"}
    )
    step = await agent.create_and_execute_step(task.task_id, step_request)
    step = await agent.execute_step(task.task_id, step_request)
    retrieved_step = await agent.get_step(task.task_id, step.step_id)
    assert retrieved_step.step_id == step.step_id


@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_artifacts(agent):
    artifacts = await agent.list_artifacts()
    assert isinstance(artifacts, list)
async def test_list_artifacts(agent: Agent):
    tasks = await agent.list_tasks()
    assert tasks.tasks, "No tasks in test.db"

    artifacts = await agent.list_artifacts(tasks.tasks[0].task_id)
    assert isinstance(artifacts.artifacts, list)


@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_artifact(agent):
async def test_create_artifact(agent: Agent, file_upload: UploadFile):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task = await agent.create_task(task_request)
    artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
    artifact = await agent.create_artifact(task.task_id, artifact_request)
    assert artifact.uri == "test_uri"
    artifact = await agent.create_artifact(
        task_id=task.task_id,
        file=file_upload,
        relative_path=f"a_dir/{file_upload.filename}",
    )
    assert artifact.file_name == file_upload.filename
    assert artifact.relative_path == f"a_dir/{file_upload.filename}"


@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_artifact(agent):
async def test_create_and_get_artifact(agent: Agent, file_upload: UploadFile):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task = await agent.create_task(task_request)
    artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
    artifact = await agent.create_artifact(task.task_id, artifact_request)

    artifact = await agent.create_artifact(
        task_id=task.task_id,
        file=file_upload,
        relative_path=f"b_dir/{file_upload.filename}",
    )
    await file_upload.seek(0)
    file_upload_content = await file_upload.read()

    retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
    assert retrieved_artifact.artifact_id == artifact.artifact_id
    retrieved_artifact_content = bytearray()
    async for b in retrieved_artifact.body_iterator:
        retrieved_artifact_content.extend(b)  # type: ignore
    assert retrieved_artifact_content == file_upload_content

@ -5,22 +5,21 @@ import inspect
import logging
from abc import ABCMeta, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Generic,
    Iterator,
    Optional,
    ParamSpec,
    TypeVar,
    cast,
    overload,
)

from colorama import Fore
from pydantic import BaseModel, Field, validator

if TYPE_CHECKING:
    from forge.models.action import ActionProposal, ActionResult

from forge.agent import protocols
from forge.agent.components import (
    AgentComponent,
@ -29,15 +28,10 @@ from forge.agent.components import (
)
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.config.config import ConfigBuilder
from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName
from forge.llm.providers.schema import ChatModelInfo
from forge.models.config import (
    Configurable,
    SystemConfiguration,
    SystemSettings,
    UserConfigurable,
)
from forge.models.action import ActionResult, AnyProposal
from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable

logger = logging.getLogger(__name__)

@ -133,17 +127,7 @@ class AgentMeta(ABCMeta):
        return instance


class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
    C = TypeVar("C", bound=AgentComponent)

    default_settings = BaseAgentSettings(
        name="BaseAgent",
        description=__doc__ if __doc__ else "",
    )

class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
    def __init__(
        self,
        settings: BaseAgentSettings,
@ -173,13 +157,13 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
        return self.config.send_token_limit or self.llm.max_tokens * 3 // 4

    @abstractmethod
    async def propose_action(self) -> ActionProposal:
    async def propose_action(self) -> AnyProposal:
        ...

    @abstractmethod
    async def execute(
        self,
        proposal: ActionProposal,
        proposal: AnyProposal,
        user_feedback: str = "",
    ) -> ActionResult:
        ...
@ -187,7 +171,7 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
    @abstractmethod
    async def do_not_execute(
        self,
        denied_proposal: ActionProposal,
        denied_proposal: AnyProposal,
        user_feedback: str,
    ) -> ActionResult:
        ...
@ -203,13 +187,16 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):

    @overload
    async def run_pipeline(
        self, protocol_method: Callable[P, None], *args, retry_limit: int = 3
        self,
        protocol_method: Callable[P, None | Awaitable[None]],
        *args,
        retry_limit: int = 3,
    ) -> list[None]:
        ...

    async def run_pipeline(
        self,
        protocol_method: Callable[P, Iterator[T] | None],
        protocol_method: Callable[P, Iterator[T] | None | Awaitable[None]],
        *args,
        retry_limit: int = 3,
    ) -> list[T] | list[None]:
@ -240,7 +227,10 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
                )
                continue

            method = getattr(component, method_name, None)
            method = cast(
                Callable[..., Iterator[T] | None | Awaitable[None]] | None,
                getattr(component, method_name, None),
            )
            if not callable(method):
                continue

@ -248,10 +238,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
            while component_attempts < retry_limit:
                try:
                    component_args = self._selective_copy(args)
                    if inspect.iscoroutinefunction(method):
                        result = await method(*component_args)
                    else:
                        result = method(*component_args)
                    result = method(*component_args)
                    if inspect.isawaitable(result):
                        result = await result
                    if result is not None:
                        method_result.extend(result)
                    args = component_args
@ -269,9 +258,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
                        break
                # Successful pipeline execution
                break
                except EndpointPipelineError:
                except EndpointPipelineError as e:
                    self._trace.append(
                        f"❌ {Fore.LIGHTRED_EX}{component.__class__.__name__}: "
                        f"❌ {Fore.LIGHTRED_EX}{e.triggerer.__class__.__name__}: "
                        f"EndpointPipelineError{Fore.RESET}"
                    )
                    # Restart from the beginning on EndpointPipelineError
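
Note the control-flow change inside `run_pipeline`: rather than branching on `inspect.iscoroutinefunction(method)` before the call, the method is always called and its result is awaited only when `inspect.isawaitable(result)` is true. That also covers plain callables that return a coroutine or future, which the old check missed. A self-contained sketch of the dispatch:

```python
import asyncio
import inspect


async def call_maybe_async(method, *args):
    result = method(*args)
    if inspect.isawaitable(result):  # coroutine, Task, or Future
        result = await result
    return result


def sync_handler(x: int) -> int:
    return x + 1


async def async_handler(x: int) -> int:
    return x + 1


async def main():
    assert await call_maybe_async(sync_handler, 1) == 2
    assert await call_maybe_async(async_handler, 1) == 2


asyncio.run(main())
```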
@ -36,8 +36,9 @@ class AgentComponent(ABC):
class ComponentEndpointError(Exception):
    """Error of a single protocol method on a component."""

    def __init__(self, message: str = ""):
    def __init__(self, message: str, component: AgentComponent):
        self.message = message
        self.triggerer = component
        super().__init__(message)

@ -1,14 +1,13 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING, Awaitable, Generic, Iterator

from forge.models.action import ActionResult, AnyProposal

from .components import AgentComponent

if TYPE_CHECKING:
    from forge.command.command import Command
    from forge.llm.providers import ChatMessage
    from forge.models.action import ActionResult

    from .base import ActionProposal


class DirectiveProvider(AgentComponent):
@ -34,19 +33,19 @@ class MessageProvider(AgentComponent):
        ...


class AfterParse(AgentComponent):
class AfterParse(AgentComponent, Generic[AnyProposal]):
    @abstractmethod
    def after_parse(self, result: "ActionProposal") -> None:
    def after_parse(self, result: AnyProposal) -> None | Awaitable[None]:
        ...


class ExecutionFailure(AgentComponent):
    @abstractmethod
    def execution_failure(self, error: Exception) -> None:
    def execution_failure(self, error: Exception) -> None | Awaitable[None]:
        ...


class AfterExecute(AgentComponent):
    @abstractmethod
    def after_execute(self, result: "ActionResult") -> None:
    def after_execute(self, result: "ActionResult") -> None | Awaitable[None]:
        ...
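
With the protocol methods now returning `None | Awaitable[None]`, a component can implement `after_parse`, `execution_failure`, or `after_execute` either synchronously or as a coroutine, and the pipeline awaits the result when necessary. A sketch of an async implementation (import paths follow the module layout shown in this diff and are assumptions; the logging body is illustrative):

```python
import logging

from forge.agent.protocols import AfterExecute
from forge.models.action import ActionResult

logger = logging.getLogger(__name__)


class AuditLogComponent(AfterExecute):
    # An async body is fine here; run_pipeline awaits the returned coroutine.
    async def after_execute(self, result: "ActionResult") -> None:
        logger.info("action result: %s", result)
```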
@ -1,39 +1,16 @@
|
||||
"""
|
||||
Routes for the Agent Service.
|
||||
|
||||
This module defines the API routes for the Agent service. While there are multiple endpoints provided by the service,
|
||||
the ones that require special attention due to their complexity are:
|
||||
This module defines the API routes for the Agent service.
|
||||
|
||||
1. `execute_agent_task_step`:
|
||||
This route is significant because this is where the agent actually performs the work. The function handles
|
||||
executing the next step for a task based on its current state, and it requires careful implementation to ensure
|
||||
all scenarios (like the presence or absence of steps or a step marked as `last_step`) are handled correctly.
|
||||
|
||||
2. `upload_agent_task_artifacts`:
|
||||
This route allows for the upload of artifacts, supporting various URI types (e.g., s3, gcs, ftp, http).
|
||||
The support for different URI types makes it a bit more complex, and it's important to ensure that all
|
||||
supported URI types are correctly managed. NOTE: The AutoGPT team will eventually handle the most common
|
||||
uri types for you.
|
||||
|
||||
3. `create_agent_task`:
|
||||
While this is a simpler route, it plays a crucial role in the workflow, as it's responsible for the creation
|
||||
of a new task.
|
||||
|
||||
Developers and contributors should be especially careful when making modifications to these routes to ensure
|
||||
consistency and correctness in the system's behavior.
|
||||
Developers and contributors should be especially careful when making modifications
|
||||
to these routes to ensure consistency and correctness in the system's behavior.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from fastapi import APIRouter, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from forge.utils.exceptions import (
|
||||
NotFoundError,
|
||||
get_detailed_traceback,
|
||||
get_exception_message,
|
||||
)
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from .models import (
|
||||
Artifact,
|
||||
@ -46,6 +23,9 @@ from .models import (
|
||||
TaskStepsListResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.agent.agent import Agent
|
||||
|
||||
base_router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -73,10 +53,10 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
|
||||
|
||||
Args:
|
||||
request (Request): FastAPI request object.
|
||||
task (TaskRequestBody): The task request containing input and additional input data.
|
||||
task (TaskRequestBody): The task request containing input data.
|
||||
|
||||
Returns:
|
||||
Task: A new task with task_id, input, additional_input, and empty lists for artifacts and steps.
|
||||
Task: A new task with task_id, input, and additional_input set.
|
||||
|
||||
Example:
|
||||
Request (TaskRequestBody defined in schema.py):
|
||||
@ -93,46 +73,32 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
|
||||
"artifacts": [],
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
agent: "Agent" = request["agent"]
|
||||
|
||||
try:
|
||||
task_request = await agent.create_task(task_request)
|
||||
return Response(
|
||||
content=task_request.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
return task
|
||||
except Exception:
|
||||
logger.exception(f"Error whilst trying to create a task: {task_request}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
|
||||
async def list_agent_tasks(
|
||||
request: Request,
|
||||
page: Optional[int] = Query(1, ge=1),
|
||||
page_size: Optional[int] = Query(10, ge=1),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1),
|
||||
) -> TaskListResponse:
|
||||
"""
|
||||
Retrieves a paginated list of all tasks.
|
||||
|
||||
Args:
|
||||
request (Request): FastAPI request object.
|
||||
page (int, optional): The page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): The number of tasks per page for pagination. Defaults to 10.
|
||||
page (int, optional): Page number for pagination. Default: 1
|
||||
page_size (int, optional): Number of tasks per page for pagination. Default: 10
|
||||
|
||||
Returns:
|
||||
TaskListResponse: A response object containing a list of tasks and pagination details.
|
||||
TaskListResponse: A list of tasks, and pagination details.
|
||||
|
||||
Example:
|
||||
Request:
|
||||
@ -158,34 +124,13 @@ async def list_agent_tasks(
|
||||
}
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
agent: "Agent" = request["agent"]
|
||||
try:
|
||||
tasks = await agent.list_tasks(page, page_size)
|
||||
return Response(
|
||||
content=tasks.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.exception("Error whilst trying to list tasks")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Tasks not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
return tasks
|
||||
except Exception:
|
||||
logger.exception("Error whilst trying to list tasks")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise


@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
@ -239,36 +184,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
                }
            ]
        }
    """
    agent = request["agent"]
    """  # noqa: E501
    agent: "Agent" = request["agent"]
    try:
        task = await agent.get_task(task_id)

        return Response(
            content=task.json(),
            status_code=200,
            media_type="application/json",
        )
    except NotFoundError:
        logger.exception(f"Error whilst trying to get task: {task_id}")
        return Response(
            content=json.dumps({"error": "Task not found"}),
            status_code=404,
            media_type="application/json",
        )
        return task
    except Exception:
        logger.exception(f"Error whilst trying to get task: {task_id}")
        return Response(
            content=json.dumps(
                {
                    "error": "Internal server error",
                    "exception": get_exception_message(),
                    "traceback": get_detailed_traceback(),
                }
            ),
            status_code=500,
            media_type="application/json",
        )
        raise


@base_router.get(
@ -279,8 +202,8 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
async def list_agent_task_steps(
    request: Request,
    task_id: str,
    page: Optional[int] = Query(1, ge=1),
    page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskStepsListResponse:
    """
    Retrieves a paginated list of steps associated with a specific task.
@ -289,10 +212,10 @@ async def list_agent_task_steps(
        request (Request): FastAPI request object.
        task_id (str): The ID of the task.
        page (int, optional): The page number for pagination. Defaults to 1.
        page_size (int, optional): The number of steps per page for pagination. Defaults to 10.
        page_size (int, optional): Number of steps per page for pagination. Default: 10.

    Returns:
        TaskStepsListResponse: A response object containing a list of steps and pagination details.
        TaskStepsListResponse: A list of steps, and pagination details.

    Example:
        Request:
@ -315,54 +238,40 @@ async def list_agent_task_steps(
                "pageSize": 10
            }
        }
    """
    agent = request["agent"]
    """  # noqa: E501
    agent: "Agent" = request["agent"]
    try:
        steps = await agent.list_steps(task_id, page, page_size)
        return Response(
            content=steps.json(),
            status_code=200,
            media_type="application/json",
        )
    except NotFoundError:
        logger.exception("Error whilst trying to list steps")
        return Response(
            content=json.dumps({"error": "Steps not found"}),
            status_code=404,
            media_type="application/json",
        )
        return steps
    except Exception:
        logger.exception("Error whilst trying to list steps")
        return Response(
            content=json.dumps(
                {
                    "error": "Internal server error",
                    "exception": get_exception_message(),
                    "traceback": get_detailed_traceback(),
                }
            ),
            status_code=500,
            media_type="application/json",
        )
        raise
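
For reference, a client-side sketch of the pagination contract documented above. `httpx`, the base URL, and the `total_pages` field of the pagination object are assumptions, not taken from this diff:

import httpx

BASE_URL = "http://localhost:8000"  # assumed server address

def fetch_all_steps(task_id: str, page_size: int = 10) -> list[dict]:
    """Walk the page/pageSize query parameters until every step is collected."""
    steps: list[dict] = []
    page = 1
    while True:
        resp = httpx.get(
            f"{BASE_URL}/agent/tasks/{task_id}/steps",
            params={"page": page, "pageSize": page_size},
        )
        resp.raise_for_status()
        body = resp.json()
        steps.extend(body["steps"])
        if page >= body["pagination"]["total_pages"]:  # field name assumed
            return steps
        page += 1
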


@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
async def execute_agent_task_step(
    request: Request, task_id: str, step: Optional[StepRequestBody] = None
    request: Request, task_id: str, step_request: Optional[StepRequestBody] = None
) -> Step:
    """
    Executes the next step for a specified task based on the current task status and returns the
    executed step with additional feedback fields.
    Executes the next step for a specified task based on the current task status and
    returns the executed step with additional feedback fields.

    Depending on the current state of the task, the following scenarios are supported:
    This route is significant because this is where the agent actually performs work.
    The function handles executing the next step for a task based on its current state,
    and it requires careful implementation to ensure all scenarios (like the presence
    or absence of steps or a step marked as `last_step`) are handled correctly.

    Depending on the current state of the task, the following scenarios are possible:
    1. No steps exist for the task.
    2. There is at least one step already for the task, and the task does not have a completed step marked as `last_step`.
    2. There is at least one step already for the task, and the task does not have a
       completed step marked as `last_step`.
    3. There is a completed step marked as `last_step` already on the task.

    In each of these scenarios, a step object will be returned with two additional fields: `output` and `additional_output`.
    In each of these scenarios, a step object will be returned with two additional
    fields: `output` and `additional_output`.
    - `output`: Provides the primary response or feedback to the user.
    - `additional_output`: Supplementary information or data. Its specific content is not strictly defined and can vary based on the step or agent's implementation.
    - `additional_output`: Supplementary information or data. Its specific content is
      not strictly defined and can vary based on the step or agent's implementation.

    Args:
        request (Request): FastAPI request object.
@ -389,39 +298,17 @@ async def execute_agent_task_step(
        ...
    }
    """
    agent = request["agent"]
    agent: "Agent" = request["agent"]
    try:
        # An empty step request represents a yes to continue command
        if not step:
            step = StepRequestBody(input="y")
        if not step_request:
            step_request = StepRequestBody(input="y")

        step = await agent.execute_step(task_id, step)

        return Response(
            content=step.json(),
            status_code=200,
            media_type="application/json",
        )
    except NotFoundError:
        logger.exception(f"Error whilst trying to execute a task step: {task_id}")
        return Response(
            content=json.dumps({"error": f"Task not found {task_id}"}),
            status_code=404,
            media_type="application/json",
        )
        step = await agent.execute_step(task_id, step_request)
        return step
    except Exception:
        logger.exception(f"Error whilst trying to execute a task step: {task_id}")
        return Response(
            content=json.dumps(
                {
                    "error": "Internal server error",
                    "exception": get_exception_message(),
                    "traceback": get_detailed_traceback(),
                }
            ),
            status_code=500,
            media_type="application/json",
        )
        raise
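
The "empty request means continue" behaviour described above can be exercised from a client like this (httpx and the server address are assumptions):

import httpx

task_id = "50da533e-3904-4401-8a07-c49adf88b5eb"  # example ID from the docstring
# Posting with no body is treated as StepRequestBody(input="y"), i.e. "continue".
resp = httpx.post(f"http://localhost:8000/agent/tasks/{task_id}/steps")
resp.raise_for_status()
step = resp.json()
print(step["output"], step["is_last"])
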


@base_router.get(
@ -450,31 +337,13 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
        ...
    }
    """
    agent = request["agent"]
    agent: "Agent" = request["agent"]
    try:
        step = await agent.get_step(task_id, step_id)

        return Response(content=step.json(), status_code=200)
    except NotFoundError:
        logger.exception(f"Error whilst trying to get step: {step_id}")
        return Response(
            content=json.dumps({"error": "Step not found"}),
            status_code=404,
            media_type="application/json",
        )
        return step
    except Exception:
        logger.exception(f"Error whilst trying to get step: {step_id}")
        return Response(
            content=json.dumps(
                {
                    "error": "Internal server error",
                    "exception": get_exception_message(),
                    "traceback": get_detailed_traceback(),
                }
            ),
            status_code=500,
            media_type="application/json",
        )
        raise


@base_router.get(
@ -485,8 +354,8 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
async def list_agent_task_artifacts(
    request: Request,
    task_id: str,
    page: Optional[int] = Query(1, ge=1),
    page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskArtifactsListResponse:
    """
    Retrieves a paginated list of artifacts associated with a specific task.
@ -495,10 +364,10 @@ async def list_agent_task_artifacts(
        request (Request): FastAPI request object.
        task_id (str): The ID of the task.
        page (int, optional): The page number for pagination. Defaults to 1.
        page_size (int, optional): The number of items per page for pagination. Defaults to 10.
        page_size (int, optional): Number of items per page for pagination. Default: 10.

    Returns:
        TaskArtifactsListResponse: A response object containing a list of artifacts and pagination details.
        TaskArtifactsListResponse: A list of artifacts, and pagination details.

    Example:
        Request:
@ -518,52 +387,33 @@ async def list_agent_task_artifacts(
                "pageSize": 10
            }
        }
    """
    agent = request["agent"]
    """  # noqa: E501
    agent: "Agent" = request["agent"]
    try:
        artifacts: TaskArtifactsListResponse = await agent.list_artifacts(
            task_id, page, page_size
        )
        artifacts = await agent.list_artifacts(task_id, page, page_size)
        return artifacts
    except NotFoundError:
        logger.exception("Error whilst trying to list artifacts")
        return Response(
            content=json.dumps({"error": "Artifacts not found for task_id"}),
            status_code=404,
            media_type="application/json",
        )
    except Exception:
        logger.exception("Error whilst trying to list artifacts")
        return Response(
            content=json.dumps(
                {
                    "error": "Internal server error",
                    "exception": get_exception_message(),
                    "traceback": get_detailed_traceback(),
                }
            ),
            status_code=500,
            media_type="application/json",
        )
        raise


@base_router.post(
    "/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact
)
async def upload_agent_task_artifacts(
    request: Request, task_id: str, file: UploadFile, relative_path: Optional[str] = ""
    request: Request, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact:
    """
    This endpoint is used to upload an artifact associated with a specific task. The artifact is provided as a file.
    This endpoint is used to upload an artifact (file) associated with a specific task.

    Args:
        request (Request): The FastAPI request object.
        task_id (str): The unique identifier of the task for which the artifact is being uploaded.
        task_id (str): The ID of the task for which the artifact is being uploaded.
        file (UploadFile): The file being uploaded as an artifact.
        relative_path (str): The relative path for the file. This is a query parameter.

    Returns:
        Artifact: An object containing metadata of the uploaded artifact, including its unique identifier.
        Artifact: Metadata object for the uploaded artifact, including its ID and path.

    Example:
        Request:
@ -579,35 +429,17 @@ async def upload_agent_task_artifacts(
        "relative_path": "/my_folder/my_other_folder/",
        "file_name": "main.py"
    }
    """
    agent = request["agent"]
    """  # noqa: E501
    agent: "Agent" = request["agent"]

    if file is None:
        return Response(
            content=json.dumps({"error": "File must be specified"}),
            status_code=404,
            media_type="application/json",
        )
        raise HTTPException(status_code=400, detail="File must be specified")
    try:
        artifact = await agent.create_artifact(task_id, file, relative_path)
        return Response(
            content=artifact.json(),
            status_code=200,
            media_type="application/json",
        )
        return artifact
    except Exception:
        logger.exception(f"Error whilst trying to upload artifact: {task_id}")
        return Response(
            content=json.dumps(
                {
                    "error": "Internal server error",
                    "exception": get_exception_message(),
                    "traceback": get_detailed_traceback(),
                }
            ),
            status_code=500,
            media_type="application/json",
        )
        raise
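
A client-side sketch of the upload contract above: the file travels as multipart form data, while `relative_path` rides along as a query parameter, matching the signature shown. httpx and the server address are assumptions:

import httpx

task_id = "50da533e-3904-4401-8a07-c49adf88b5eb"  # example ID
with open("main.py", "rb") as f:
    resp = httpx.post(
        f"http://localhost:8000/agent/tasks/{task_id}/artifacts",
        params={"relative_path": "/my_folder/my_other_folder/"},
        files={"file": ("main.py", f)},  # field name matches the UploadFile parameter
    )
resp.raise_for_status()
print(resp.json()["artifact_id"])
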


@base_router.get(
@ -617,7 +449,7 @@ async def upload_agent_task_artifacts(
)
async def download_agent_task_artifact(
    request: Request, task_id: str, artifact_id: str
) -> FileResponse:
) -> StreamingResponse:
    """
    Downloads an artifact associated with a specific task.

@ -636,32 +468,9 @@ async def download_agent_task_artifact(
        Response:
        <file_content_of_artifact>
    """
    agent = request["agent"]
    agent: "Agent" = request["agent"]
    try:
        return await agent.get_artifact(task_id, artifact_id)
    except NotFoundError:
        logger.exception(f"Error whilst trying to download artifact: {task_id}")
        return Response(
            content=json.dumps(
                {
                    "error": f"Artifact not found "
                    f"- task_id: {task_id}, artifact_id: {artifact_id}"
                }
            ),
            status_code=404,
            media_type="application/json",
        )
    except Exception:
        logger.exception(f"Error whilst trying to download artifact: {task_id}")
        return Response(
            content=json.dumps(
                {
                    "error": f"Internal server error "
                    f"- task_id: {task_id}, artifact_id: {artifact_id}",
                    "exception": get_exception_message(),
                    "traceback": get_detailed_traceback(),
                }
            ),
            status_code=500,
            media_type="application/json",
        )
        raise

@ -1 +1,3 @@
from .db import AgentDB

__all__ = ["AgentDB"]

@ -4,23 +4,22 @@ It uses SQLite as the database and file store backend.
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
"""

import datetime
import logging
import math
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple

from sqlalchemy import (
    JSON,
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    String,
    create_engine,
)
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    joinedload,
    mapped_column,
    relationship,
    sessionmaker,
)

from forge.utils.exceptions import NotFoundError

@ -32,18 +31,20 @@ logger = logging.getLogger(__name__)


class Base(DeclarativeBase):
    pass
    type_annotation_map = {
        dict[str, Any]: JSON,
    }


class TaskModel(Base):
    __tablename__ = "tasks"

    task_id = Column(String, primary_key=True, index=True)
    input = Column(String)
    additional_input = Column(JSON)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
    task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
    input: Mapped[str]
    additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    modified_at: Mapped[datetime] = mapped_column(
        DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
    )

    artifacts = relationship("ArtifactModel", back_populates="task")
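
The model changes above port the ORM classes from legacy `Column(...)` attributes to SQLAlchemy 2.0's typed `Mapped[...]`/`mapped_column()` style, with `type_annotation_map` teaching the base class that `dict[str, Any]` means a JSON column. A minimal standalone sketch of the same pattern, with illustrative table and field names that are not from this diff:

from typing import Any, Optional

from sqlalchemy import JSON, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class ExampleBase(DeclarativeBase):
    # Plain annotations of this Python type are mapped to the JSON column type.
    type_annotation_map = {dict[str, Any]: JSON}


class NoteModel(ExampleBase):
    __tablename__ = "notes"

    note_id: Mapped[str] = mapped_column(primary_key=True)
    body: Mapped[str]  # bare annotation: NOT NULL string column
    extra: Mapped[Optional[dict[str, Any]]]  # Optional[...]: nullable JSON column


engine = create_engine("sqlite:///:memory:")
ExampleBase.metadata.create_all(engine)  # column types are inferred from the hints

Besides being less repetitive, the typed declarations are what let Pyright check reads and writes against the models, which is the point of this PR.
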

@ -52,35 +53,35 @@ class TaskModel(Base):
class StepModel(Base):
    __tablename__ = "steps"

    step_id = Column(String, primary_key=True, index=True)
    task_id = Column(String, ForeignKey("tasks.task_id"))
    name = Column(String)
    input = Column(String)
    status = Column(String)
    output = Column(String)
    is_last = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
    step_id: Mapped[str] = mapped_column(primary_key=True, index=True)
    task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
    name: Mapped[str]
    input: Mapped[str]
    status: Mapped[str]
    output: Mapped[Optional[str]]
    is_last: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    modified_at: Mapped[datetime] = mapped_column(
        DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
    )

    additional_input = Column(JSON)
    additional_output = Column(JSON)
    additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
    additional_output: Mapped[Optional[dict[str, Any]]]
    artifacts = relationship("ArtifactModel", back_populates="step")


class ArtifactModel(Base):
    __tablename__ = "artifacts"

    artifact_id = Column(String, primary_key=True, index=True)
    task_id = Column(String, ForeignKey("tasks.task_id"))
    step_id = Column(String, ForeignKey("steps.step_id"))
    agent_created = Column(Boolean, default=False)
    file_name = Column(String)
    relative_path = Column(String)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
    artifact_id: Mapped[str] = mapped_column(primary_key=True, index=True)
    task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
    step_id: Mapped[Optional[str]] = mapped_column(ForeignKey("steps.step_id"))
    agent_created: Mapped[bool] = mapped_column(default=False)
    file_name: Mapped[str]
    relative_path: Mapped[str]
    created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
    modified_at: Mapped[datetime] = mapped_column(
        default=datetime.utcnow, onupdate=datetime.utcnow
    )

    step = relationship("StepModel", back_populates="artifacts")
@ -150,6 +151,10 @@ class AgentDB:
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    def close(self) -> None:
        self.Session.close_all()
        self.engine.dispose()

    async def create_task(
        self, input: Optional[str], additional_input: Optional[dict] = {}
    ) -> Task:
@ -172,8 +177,6 @@ class AgentDB:
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while creating task: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while creating task: {e}")
            raise
@ -207,8 +210,6 @@ class AgentDB:
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while creating step: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while creating step: {e}")
            raise
@ -237,7 +238,7 @@ class AgentDB:
                    session.close()
                    if self.debug_enabled:
                        logger.debug(
                            f"Artifact already exists with relative_path: {relative_path}"
                            f"Artifact {file_name} already exists at {relative_path}/"
                        )
                    return convert_to_artifact(existing_artifact)

@ -254,14 +255,12 @@ class AgentDB:
                session.refresh(new_artifact)
                if self.debug_enabled:
                    logger.debug(
                        f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
                        f"Created new artifact with ID: {new_artifact.artifact_id}"
                    )
                return convert_to_artifact(new_artifact)
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while creating step: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while creating step: {e}")
            raise
@ -285,8 +284,6 @@ class AgentDB:
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while getting task: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while getting task: {e}")
            raise
@ -312,8 +309,6 @@ class AgentDB:
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while getting step: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while getting step: {e}")
            raise
@ -337,8 +332,6 @@ class AgentDB:
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while getting artifact: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while getting artifact: {e}")
            raise
@ -375,14 +368,13 @@ class AgentDB:
                return await self.get_step(task_id, step_id)
            else:
                logger.error(
                    f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
                    "Can't update non-existent Step with "
                    f"task_id: {task_id} and step_id: {step_id}"
                )
                raise NotFoundError("Step not found")
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while getting step: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while getting step: {e}")
            raise
@ -441,8 +433,6 @@ class AgentDB:
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while listing tasks: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while listing tasks: {e}")
            raise
@ -475,8 +465,6 @@ class AgentDB:
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while listing steps: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while listing steps: {e}")
            raise
@ -509,8 +497,6 @@ class AgentDB:
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while listing artifacts: {e}")
            raise
        except NotFoundError as e:
            raise
        except Exception as e:
            logger.error(f"Unexpected error while listing artifacts: {e}")
            raise

@ -22,14 +22,27 @@ from forge.agent_protocol.models (
)
from forge.utils.exceptions import NotFoundError as DataNotFoundError

TEST_DB_FILENAME = "test_db.sqlite3"
TEST_DB_URL = f"sqlite:///{TEST_DB_FILENAME}"

@pytest.mark.asyncio
def test_table_creation():
    db_name = "sqlite:///test_db.sqlite3"
    agent_db = AgentDB(db_name)

    conn = sqlite3.connect("test_db.sqlite3")
    cursor = conn.cursor()
@pytest.fixture
def agent_db():
    db = AgentDB(TEST_DB_URL)
    yield db
    db.close()
    os.remove(TEST_DB_FILENAME)


@pytest.fixture
def raw_db_connection(agent_db: AgentDB):
    connection = sqlite3.connect(TEST_DB_FILENAME)
    yield connection
    connection.close()


def test_table_creation(raw_db_connection: sqlite3.Connection):
    cursor = raw_db_connection.cursor()

    # Test for tasks table existence
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
@ -45,8 +58,6 @@ def test_table_creation():
    )
    assert cursor.fetchone() is not None

    os.remove(db_name.split("///")[1])


@pytest.mark.asyncio
async def test_task_schema():
@ -84,7 +95,10 @@ async def test_step_schema():
        name="Write to file",
        input="Write the words you receive to the file 'output.txt'.",
        status=StepStatus.created,
        output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
        output=(
            "I am going to use the write_to_file command and write Washington "
            "to a file called output.txt <write_to_file('output.txt', 'Washington')>"
        ),
        artifacts=[
            Artifact(
                artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -101,13 +115,13 @@ async def test_step_schema():
    assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
    assert step.name == "Write to file"
    assert step.status == StepStatus.created
    assert (
        step.output
        == "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>"
    assert step.output == (
        "I am going to use the write_to_file command and write Washington "
        "to a file called output.txt <write_to_file('output.txt', 'Washington')>"
    )
    assert len(step.artifacts) == 1
    assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
    assert step.is_last == False
    assert step.is_last is False


@pytest.mark.asyncio
@ -118,6 +132,7 @@ async def test_convert_to_task():
        created_at=now,
        modified_at=now,
        input="Write the words you receive to the file 'output.txt'.",
        additional_input={},
        artifacts=[
            ArtifactModel(
                artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -147,6 +162,7 @@ async def test_convert_to_step():
        name="Write to file",
        status="created",
        input="Write the words you receive to the file 'output.txt'.",
        additional_input={},
        artifacts=[
            ArtifactModel(
                artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -166,7 +182,7 @@ async def test_convert_to_step():
    assert step.status == StepStatus.created
    assert len(step.artifacts) == 1
    assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
    assert step.is_last == False
    assert step.is_last is False


@pytest.mark.asyncio
@ -183,91 +199,67 @@ async def test_convert_to_artifact():
    artifact = convert_to_artifact(artifact_model)
    assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
    assert artifact.relative_path == "file:///path/to/main.py"
    assert artifact.agent_created == True
    assert artifact.agent_created is True


@pytest.mark.asyncio
async def test_create_task():
    # Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
    # TODO: Fix this!
    db_name = "sqlite:///test_db.sqlite3"
    agent_db = AgentDB(db_name)

async def test_create_task(agent_db: AgentDB):
    task = await agent_db.create_task("task_input")
    assert task.input == "task_input"
    os.remove(db_name.split("///")[1])


@pytest.mark.asyncio
async def test_create_and_get_task():
    db_name = "sqlite:///test_db.sqlite3"
    agent_db = AgentDB(db_name)
async def test_create_and_get_task(agent_db: AgentDB):
    task = await agent_db.create_task("test_input")
    fetched_task = await agent_db.get_task(task.task_id)
    assert fetched_task.input == "test_input"
    os.remove(db_name.split("///")[1])


@pytest.mark.asyncio
async def test_get_task_not_found():
    db_name = "sqlite:///test_db.sqlite3"
    agent_db = AgentDB(db_name)
async def test_get_task_not_found(agent_db: AgentDB):
    with pytest.raises(DataNotFoundError):
        await agent_db.get_task(9999)
    os.remove(db_name.split("///")[1])
        await agent_db.get_task("9999")


@pytest.mark.asyncio
async def test_create_and_get_step():
    db_name = "sqlite:///test_db.sqlite3"
    agent_db = AgentDB(db_name)
async def test_create_and_get_step(agent_db: AgentDB):
    task = await agent_db.create_task("task_input")
    step_input = StepInput(type="python/code")
    step_input = {"type": "python/code"}
    request = StepRequestBody(input="test_input debug", additional_input=step_input)
    step = await agent_db.create_step(task.task_id, request)
    step = await agent_db.get_step(task.task_id, step.step_id)
    assert step.input == "test_input debug"
    os.remove(db_name.split("///")[1])


@pytest.mark.asyncio
async def test_updating_step():
    db_name = "sqlite:///test_db.sqlite3"
    agent_db = AgentDB(db_name)
async def test_updating_step(agent_db: AgentDB):
    created_task = await agent_db.create_task("task_input")
    step_input = StepInput(type="python/code")
    step_input = {"type": "python/code"}
    request = StepRequestBody(input="test_input debug", additional_input=step_input)
    created_step = await agent_db.create_step(created_task.task_id, request)
    await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")

    step = await agent_db.get_step(created_task.task_id, created_step.step_id)
    assert step.status.value == "completed"
    os.remove(db_name.split("///")[1])


@pytest.mark.asyncio
async def test_get_step_not_found():
    db_name = "sqlite:///test_db.sqlite3"
    agent_db = AgentDB(db_name)
async def test_get_step_not_found(agent_db: AgentDB):
    with pytest.raises(DataNotFoundError):
        await agent_db.get_step(9999, 9999)
    os.remove(db_name.split("///")[1])
        await agent_db.get_step("9999", "9999")


@pytest.mark.asyncio
async def test_get_artifact():
    db_name = "sqlite:///test_db.sqlite3"
    db = AgentDB(db_name)

async def test_get_artifact(agent_db: AgentDB):
    # Given: A task and its corresponding artifact
    task = await db.create_task("test_input debug")
    step_input = StepInput(type="python/code")
    task = await agent_db.create_task("test_input debug")
    step_input = {"type": "python/code"}
    request = StepRequestBody(input="test_input debug", additional_input=step_input)

    step = await db.create_step(task.task_id, request)
    step = await agent_db.create_step(task.task_id, request)

    # Create an artifact
    artifact = await db.create_artifact(
    artifact = await agent_db.create_artifact(
        task_id=task.task_id,
        file_name="test_get_artifact_sample_file.txt",
        relative_path="file:///path/to/test_get_artifact_sample_file.txt",
@ -276,7 +268,7 @@ async def test_get_artifact():
    )

    # When: The artifact is fetched by its ID
    fetched_artifact = await db.get_artifact(artifact.artifact_id)
    fetched_artifact = await agent_db.get_artifact(artifact.artifact_id)

    # Then: The fetched artifact matches the original
    assert fetched_artifact.artifact_id == artifact.artifact_id
@ -285,47 +277,37 @@ async def test_get_artifact():
        == "file:///path/to/test_get_artifact_sample_file.txt"
    )

    os.remove(db_name.split("///")[1])


@pytest.mark.asyncio
async def test_list_tasks():
    db_name = "sqlite:///test_db.sqlite3"
    db = AgentDB(db_name)

async def test_list_tasks(agent_db: AgentDB):
    # Given: Multiple tasks in the database
    task1 = await db.create_task("test_input_1")
    task2 = await db.create_task("test_input_2")
    task1 = await agent_db.create_task("test_input_1")
    task2 = await agent_db.create_task("test_input_2")

    # When: All tasks are fetched
    fetched_tasks, pagination = await db.list_tasks()
    fetched_tasks, pagination = await agent_db.list_tasks()

    # Then: The fetched tasks list includes the created tasks
    task_ids = [task.task_id for task in fetched_tasks]
    assert task1.task_id in task_ids
    assert task2.task_id in task_ids
    os.remove(db_name.split("///")[1])


@pytest.mark.asyncio
async def test_list_steps():
    db_name = "sqlite:///test_db.sqlite3"
    db = AgentDB(db_name)

    step_input = StepInput(type="python/code")
    requst = StepRequestBody(input="test_input debug", additional_input=step_input)
async def test_list_steps(agent_db: AgentDB):
    step_input = {"type": "python/code"}
    request = StepRequestBody(input="test_input debug", additional_input=step_input)

    # Given: A task and multiple steps for that task
    task = await db.create_task("test_input")
    step1 = await db.create_step(task.task_id, requst)
    requst = StepRequestBody(input="step two", additional_input=step_input)
    step2 = await db.create_step(task.task_id, requst)
    task = await agent_db.create_task("test_input")
    step1 = await agent_db.create_step(task.task_id, request)
    request = StepRequestBody(input="step two")
    step2 = await agent_db.create_step(task.task_id, request)

    # When: All steps for the task are fetched
    fetched_steps, pagination = await db.list_steps(task.task_id)
    fetched_steps, pagination = await agent_db.list_steps(task.task_id)

    # Then: The fetched steps list includes the created steps
    step_ids = [step.step_id for step in fetched_steps]
    assert step1.step_id in step_ids
    assert step2.step_id in step_ids
    os.remove(db_name.split("///")[1])

@ -1,4 +1,4 @@
from .artifact import Artifact, ArtifactUpload
from .artifact import Artifact
from .pagination import Pagination
from .task import (
    Step,
@ -10,3 +10,16 @@ from .task import (
    TaskRequestBody,
    TaskStepsListResponse,
)

__all__ = [
    "Artifact",
    "Pagination",
    "Step",
    "StepRequestBody",
    "StepStatus",
    "Task",
    "TaskArtifactsListResponse",
    "TaskListResponse",
    "TaskRequestBody",
    "TaskStepsListResponse",
]

@ -3,15 +3,6 @@ from datetime import datetime
from pydantic import BaseModel, Field


class ArtifactUpload(BaseModel):
    file: str = Field(..., description="File to upload.", format="binary")
    relative_path: str = Field(
        ...,
        description="Relative path of the artifact in the agent's workspace.",
        example="python/code",
    )


class Artifact(BaseModel):
    created_at: datetime = Field(
        ...,

@ -2,7 +2,7 @@ from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import List, Optional
from typing import Any, List, Optional

from pydantic import BaseModel, Field

@ -17,7 +17,7 @@ class TaskRequestBody(BaseModel):
        description="Input prompt for the task.",
        example="Write the words you receive to the file 'output.txt'.",
    )
    additional_input: Optional[dict] = None
    additional_input: dict[str, Any] = Field(default_factory=dict)


class Task(TaskRequestBody):
@ -38,8 +38,8 @@ class Task(TaskRequestBody):
        description="The ID of the task.",
        example="50da533e-3904-4401-8a07-c49adf88b5eb",
    )
    artifacts: Optional[List[Artifact]] = Field(
        [],
    artifacts: list[Artifact] = Field(
        default_factory=list,
        description="A list of artifacts that the task has produced.",
        example=[
            "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
@ -50,14 +50,12 @@ class Task(TaskRequestBody):

class StepRequestBody(BaseModel):
    name: Optional[str] = Field(
        None, description="The name of the task step.", example="Write to file"
        default=None, description="The name of the task step.", example="Write to file"
    )
    input: Optional[str] = Field(
        None,
        description="Input prompt for the step.",
        example="Washington",
    input: str = Field(
        ..., description="Input prompt for the step.", example="Washington"
    )
    additional_input: Optional[dict] = None
    additional_input: dict[str, Any] = Field(default_factory=dict)
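
The switch from `Optional[dict] = None` (and the earlier `[]` default) to `Field(default_factory=...)` lets the annotation drop `Optional` while each instance still gets its own fresh container. A small illustration of the mechanics, using a hypothetical model in the pydantic v1 style this codebase uses:

from typing import Any

from pydantic import BaseModel, Field


class Payload(BaseModel):
    # default_factory runs per instance, so the dicts are independent.
    additional_input: dict[str, Any] = Field(default_factory=dict)


a, b = Payload(), Payload()
a.additional_input["k"] = "v"
assert b.additional_input == {}  # b is unaffected
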


class StepStatus(Enum):
@ -90,19 +88,23 @@ class Step(StepRequestBody):
        example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
    )
    name: Optional[str] = Field(
        None, description="The name of the task step.", example="Write to file"
        default=None, description="The name of the task step.", example="Write to file"
    )
    status: StepStatus = Field(
        ..., description="The status of the task step.", example="created"
    )
    output: Optional[str] = Field(
        None,
        default=None,
        description="Output of the task step.",
        example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')",
        example=(
            "I am going to use the write_to_file command and write Washington "
            "to a file called output.txt <write_to_file('output.txt', 'Washington')"
        ),
    )
    additional_output: Optional[dict] = None
    artifacts: Optional[List[Artifact]] = Field(
        [], description="A list of artifacts that the step has produced."
    additional_output: Optional[dict[str, Any]] = None
    artifacts: list[Artifact] = Field(
        default_factory=list,
        description="A list of artifacts that the step has produced.",
    )
    is_last: bool = Field(
        ..., description="Whether this is the last step in the task.", example=True

@ -1,3 +1,5 @@
from .command import Command, CommandOutput, CommandParameter
from .command import Command
from .decorator import command
from .parameter import CommandParameter

__all__ = ["Command", "CommandParameter", "command"]

@ -1,14 +1,16 @@
from __future__ import annotations

import inspect
from typing import Any, Callable, Generic, ParamSpec, TypeVar
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast

from forge.agent.protocols import CommandProvider

from .parameter import CommandParameter

CommandOutput = Any

P = ParamSpec("P")
CO = TypeVar("CO", bound=CommandOutput)
CO = TypeVar("CO")  # command output

_CP = TypeVar("_CP", bound=CommandProvider)


class Command(Generic[P, CO]):
@ -24,7 +26,7 @@ class Command(Generic[P, CO]):
        self,
        names: list[str],
        description: str,
        method: Callable[P, CO],
        method: Callable[Concatenate[_CP, P], CO],
        parameters: list[CommandParameter],
    ):
        # Check if all parameters are provided
@ -34,7 +36,9 @@ class Command(Generic[P, CO]):
        )
        self.names = names
        self.description = description
        self.method = method
        # Method technically has a `self` parameter, but we can ignore that
        # since Python passes it internally.
        self.method = cast(Callable[P, CO], method)
        self.parameters = parameters
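
`Concatenate[_CP, P]` lets the constructor accept an unbound method: the leading `self` parameter (a `CommandProvider`) is typed explicitly, and the `cast` then erases it from the static view, since at call time the method is bound and no longer takes it. A reduced sketch of the same trick with hypothetical names, binding the instance eagerly where `Command` instead relies on Python's method binding:

from functools import partial
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast

P = ParamSpec("P")
R = TypeVar("R")


class Service:
    def greet(self, name: str) -> str:
        return f"hello {name}"


class Wrapper(Generic[P, R]):
    def __init__(
        self, instance: Service, method: Callable[Concatenate[Service, P], R]
    ) -> None:
        # Concatenate types the unbound method's leading `self`; after binding
        # it here, the cast presents a plain Callable[P, R] to the type checker.
        self.method = cast(Callable[P, R], partial(method, instance))

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
        return self.method(*args, **kwargs)


w = Wrapper(Service(), Service.greet)
print(w("bob"))  # -> "hello bob"
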

    @property
@ -62,7 +66,8 @@ class Command(Generic[P, CO]):
    def __str__(self) -> str:
        params = [
            f"{param.name}: "
            + ("%s" if param.spec.required else "Optional[%s]") % param.spec.type.value
            + ("%s" if param.spec.required else "Optional[%s]")
            % (param.spec.type.value if param.spec.type else "Any")
            for param in self.parameters
        ]
        return (

@ -1,2 +1,4 @@
from .action_history import ActionHistoryComponent
from .model import Episode, EpisodicActionHistory

__all__ = ["ActionHistoryComponent", "Episode", "EpisodicActionHistory"]

@ -1,27 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional
from typing import TYPE_CHECKING, Callable, Iterator, Optional

from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
from forge.llm.prompting.utils import indent
from forge.llm.providers import ChatMessage, ChatModelProvider
from forge.llm.providers import ChatMessage, MultiProvider

if TYPE_CHECKING:
    from forge.config.config import Config

from .model import AP, ActionResult, Episode, EpisodicActionHistory
from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory


class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[AP]):
class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
    """Keeps track of the event history and provides a summary of the steps."""

    def __init__(
        self,
        event_history: EpisodicActionHistory[AP],
        event_history: EpisodicActionHistory[AnyProposal],
        max_tokens: int,
        count_tokens: Callable[[str], int],
        legacy_config: Config,
        llm_provider: ChatModelProvider,
        llm_provider: MultiProvider,
    ) -> None:
        self.event_history = event_history
        self.max_tokens = max_tokens
@ -37,7 +37,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
        ):
            yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")

    def after_parse(self, result: AP) -> None:
    def after_parse(self, result: AnyProposal) -> None:
        self.event_history.register_action(result)

    async def after_execute(self, result: ActionResult) -> None:
@ -48,7 +48,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
    def _compile_progress(
        self,
        episode_history: list[Episode],
        episode_history: list[Episode[AnyProposal]],
        max_tokens: Optional[int] = None,
        count_tokens: Optional[Callable[[str], int]] = None,
    ) -> str:

@ -1,25 +1,23 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Generic, Iterator, TypeVar
from typing import TYPE_CHECKING, Generic

from pydantic import Field
from pydantic.generics import GenericModel

from forge.content_processing.text import summarize_text
from forge.llm.prompting.utils import format_numbered_list, indent
from forge.models.action import ActionProposal, ActionResult
from forge.models.action import ActionResult, AnyProposal
from forge.models.utils import ModelWithSummary

if TYPE_CHECKING:
    from forge.config.config import Config
    from forge.llm.providers import ChatModelProvider

AP = TypeVar("AP", bound=ActionProposal)
    from forge.llm.providers import MultiProvider


class Episode(GenericModel, Generic[AP]):
    action: AP
class Episode(GenericModel, Generic[AnyProposal]):
    action: AnyProposal
    result: ActionResult | None
    summary: str | None = None
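
`Episode` is a pydantic v1 generic model, now parametrized by the shared `AnyProposal` TypeVar instead of the module-local `AP`. A minimal standalone illustration of the `GenericModel` mechanics, with hypothetical payload types that are not from this diff:

from typing import Generic, TypeVar

from pydantic import BaseModel
from pydantic.generics import GenericModel

T = TypeVar("T")


class Proposal(BaseModel):
    thoughts: str


class Record(GenericModel, Generic[T]):
    action: T
    summary: str | None = None


# Record[Proposal] is a concrete model; `action` is validated as a Proposal.
r = Record[Proposal](action=Proposal(thoughts="write the file"))
print(r.action.thoughts)
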

@ -54,32 +52,29 @@ class Episode(GenericModel, Generic[AP]):
        return executed_action + action_result


class EpisodicActionHistory(GenericModel, Generic[AP]):
class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
    """Utility container for an action history"""

    episodes: list[Episode[AP]] = Field(default_factory=list)
    episodes: list[Episode[AnyProposal]] = Field(default_factory=list)
    cursor: int = 0
    _lock = asyncio.Lock()

    @property
    def current_episode(self) -> Episode[AP] | None:
    def current_episode(self) -> Episode[AnyProposal] | None:
        if self.cursor == len(self):
            return None
        return self[self.cursor]

    def __getitem__(self, key: int) -> Episode[AP]:
    def __getitem__(self, key: int) -> Episode[AnyProposal]:
        return self.episodes[key]

    def __iter__(self) -> Iterator[Episode[AP]]:
        return iter(self.episodes)

    def __len__(self) -> int:
        return len(self.episodes)

    def __bool__(self) -> bool:
        return len(self.episodes) > 0

    def register_action(self, action: AP) -> None:
    def register_action(self, action: AnyProposal) -> None:
        if not self.current_episode:
            self.episodes.append(Episode(action=action, result=None))
        assert self.current_episode
@ -113,7 +108,7 @@ class EpisodicActionHistory(GenericModel, Generic[AP]):
            self.cursor = len(self.episodes)

    async def handle_compression(
        self, llm_provider: ChatModelProvider, app_config: Config
        self, llm_provider: MultiProvider, app_config: Config
    ) -> None:
        """Compresses each episode in the action history using an LLM.

@ -3,6 +3,11 @@ from .code_executor import (
    DENYLIST_CONTROL,
    CodeExecutionError,
    CodeExecutorComponent,
    is_docker_available,
    we_are_running_in_a_docker_container,
)

__all__ = [
    "ALLOWLIST_CONTROL",
    "DENYLIST_CONTROL",
    "CodeExecutionError",
    "CodeExecutorComponent",
]
Some files were not shown because too many files have changed in this diff.