Mirror of https://github.com/Significant-Gravitas/Auto-GPT.git (synced 2025-01-09 04:19:02 +08:00)
Set up unified pre-commit + CI w/ linting + type checking & FIX EVERYTHING (#7171)
- **FIX ALL LINT/TYPE ERRORS IN AUTOGPT, FORGE, AND BENCHMARK**

### Linting
- Clean up linter configs for `autogpt`, `forge`, and `benchmark`
- Add type checking with Pyright
- Create unified pre-commit config
- Create unified linting and type checking CI workflow

### Testing
- Synchronize CI test setups for `autogpt`, `forge`, and `benchmark`
  - Add missing `pytest-cov` to benchmark dependencies
  - Mark GCS tests as slow to speed up pre-commit test runs
- Repair `forge` test suite
  - Add `AgentDB.close()` method for test DB teardown in `db_test.py`
  - Use an actual temporary directory instead of `forge/test_workspace/`
- Move dependencies of moved `forge` code from `autogpt` to `forge`

### Notable type changes
- Replace uses of `ChatModelProvider` with `MultiProvider`
- Remove unnecessary exports from various `__init__.py` files
- Simplify the `FileStorage.open_file` signature by removing `IOBase` from the return type union
- Implement an `S3BinaryIOWrapper(BinaryIO)` type interposer for `S3FileStorage`
- Expand the overloads of `GCSFileStorage.open_file` for improved typing of read and write modes.
  Type checking had to be silenced for the extra overloads, because Pyright seems to report a false positive: https://github.com/microsoft/pyright/issues/8007
- Change the `count_tokens`, `get_tokenizer`, and `count_message_tokens` methods on `ModelProvider`s from class methods to instance methods
- Move the `CompletionModelFunction.schema` method to the helper function `format_function_def_for_openai` in `forge.llm.providers.openai`
- Rename `ModelProvider` to `BaseModelProvider`
- Rename `ChatModelProvider` to `BaseChatModelProvider`
- Add the type `ChatModelProvider`, a union of all subclasses of `BaseChatModelProvider` (see the sketch below)

### Removed rather than fixed
- Remove the deprecated and broken `autogpt/agbenchmark_config/benchmarks.py`
- Remove various base classes and properties on base classes in `forge.llm.providers.schema` and `forge.models.providers`

### Fixes for other issues that came to light
- Clean up `forge.agent_protocol.api_router`, `forge.agent_protocol.database`, and `forge.agent.agent`
- Add fallback behavior to `ImageGeneratorComponent`, and remove the test for the deprecated failure behavior
- Fix the `agbenchmark.challenges.builtin` challenge exclusion mechanism on Windows
- Fix `_tool_calls_compat_extract_calls` in `forge.llm.providers.openai`
- Add support for `any` (= no type specified) in `JSONSchema.typescript_type`
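The `BaseChatModelProvider` / `ChatModelProvider` split referenced above follows a common typing pattern. Below is a minimal sketch of that pattern, not the actual forge code; the two concrete provider classes are hypothetical stand-ins for whatever implementations the union actually covers.

```python
from abc import ABC, abstractmethod


class BaseChatModelProvider(ABC):
    """The renamed abstract base (formerly `ChatModelProvider`)."""

    @abstractmethod
    async def create_chat_completion(self, prompt: str) -> str: ...


# Hypothetical concrete providers, for illustration only:
class OpenAIProvider(BaseChatModelProvider):
    async def create_chat_completion(self, prompt: str) -> str:
        return "response from OpenAI"


class OtherProvider(BaseChatModelProvider):
    async def create_chat_completion(self, prompt: str) -> str:
        return "response from another backend"


# The new `ChatModelProvider` name is a closed union of the known
# implementations, so a type checker like Pyright can narrow a value to a
# concrete provider instead of only seeing the abstract base.
ChatModelProvider = OpenAIProvider | OtherProvider
```

The payoff of the union is in narrowing: `isinstance(p, OpenAIProvider)` on a `ChatModelProvider` leaves the checker knowing the exact remaining type, which an open ABC hierarchy cannot offer.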
This commit is contained in: parent 2c13a2706c, commit f107ff8cf0.
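The `S3BinaryIOWrapper(BinaryIO)` interposer mentioned in the commit message is not shown in the diff below. As a hedged sketch of the general technique only, assuming nothing more than an object with a `read(n)` method (such as an S3 response body), an adapter of that kind can be built on `io.RawIOBase`:

```python
import io
from typing import BinaryIO


class BinaryReadWrapper(io.RawIOBase):
    """Hypothetical sketch: adapt a bare readable body (e.g. an S3 response
    stream) to Python's io machinery so callers can treat it as a BinaryIO
    file object. Not the actual forge implementation."""

    def __init__(self, body) -> None:
        self._body = body  # anything with a .read(n) method

    def readable(self) -> bool:
        return True

    def readinto(self, buffer) -> int:
        # Pull at most len(buffer) bytes from the body into the given buffer.
        data = self._body.read(len(buffer))
        buffer[: len(data)] = data
        return len(data)


# Wrapping in BufferedReader yields a buffered, BinaryIO-compatible stream:
stream: BinaryIO = io.BufferedReader(BinaryReadWrapper(io.BytesIO(b"hello")))
print(stream.read())  # b'hello'
```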
.github/workflows/autogpt-ci.yml (vendored): 53 changes

@@ -1,4 +1,4 @@
-name: AutoGPT Python CI
+name: AutoGPT CI
 
 on:
   push:
@@ -24,57 +24,6 @@ defaults:
     working-directory: autogpt
 
 jobs:
-  lint:
-    runs-on: ubuntu-latest
-    env:
-      min-python-version: "3.10"
-
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-
-      - name: Set up Python ${{ env.min-python-version }}
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ env.min-python-version }}
-
-      - id: get_date
-        name: Get date
-        run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
-
-      - name: Set up Python dependency cache
-        uses: actions/cache@v4
-        with:
-          path: ~/.cache/pypoetry
-          key: ${{ runner.os }}-poetry-${{ hashFiles('autogpt/pyproject.toml') }}-${{ steps.get_date.outputs.date }}
-
-      - name: Install Python dependencies
-        run: |
-          curl -sSL https://install.python-poetry.org | python3 -
-          poetry install
-
-      - name: Lint with flake8
-        run: poetry run flake8
-
-      - name: Check black formatting
-        run: poetry run black . --check
-        if: success() || failure()
-
-      - name: Check isort formatting
-        run: poetry run isort . --check
-        if: success() || failure()
-
-      # - name: Check mypy formatting
-      #   run: poetry run mypy
-      #   if: success() || failure()
-
-      # - name: Check for unused imports and pass statements
-      #   run: |
-      #     cmd="autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests"
-      #     poetry run $cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
-
   test:
     permissions:
       contents: read
.github/workflows/autogpts-ci.yml (vendored): 4 changes

@@ -1,4 +1,4 @@
-name: AutoGPTs smoke test CI
+name: Agent smoke tests
 
 on:
   workflow_dispatch:
@@ -28,7 +28,7 @@ on:
     - '!**/*.md'
 
 jobs:
-  run-tests:
+  serve-agent-protocol:
     runs-on: ubuntu-latest
     strategy:
       matrix:
.github/workflows/benchmark-ci.yml (vendored): 102 changes

@@ -1,4 +1,4 @@
-name: Benchmark CI
+name: AGBenchmark CI
 
 on:
   push:
@@ -14,62 +14,91 @@ on:
     - '!benchmark/reports/**'
     - .github/workflows/benchmark-ci.yml
 
+concurrency:
+  group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
+  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
+
+defaults:
+  run:
+    shell: bash
+
 env:
   min-python-version: '3.10'
 
 jobs:
-  lint:
-    runs-on: ubuntu-latest
+  test:
+    permissions:
+      contents: read
+    timeout-minutes: 30
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
+    defaults:
+      run:
+        shell: bash
+        working-directory: benchmark
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
         with:
           fetch-depth: 0
+          submodules: true
 
-      - name: Set up Python ${{ env.min-python-version }}
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
-          python-version: ${{ env.min-python-version }}
+          python-version: ${{ matrix.python-version }}
 
-      - id: get_date
-        name: Get date
-        working-directory: ./benchmark/
-        run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
+      - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
+        uses: actions/cache@v4
+        with:
+          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
+          key: poetry-${{ runner.os }}-${{ hashFiles('benchmark/poetry.lock') }}
 
-      - name: Install Poetry
-        working-directory: ./benchmark/
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
         run: |
-          curl -sSL https://install.python-poetry.org | python -
+          curl -sSL https://install.python-poetry.org | python3 -
 
-      - name: Install dependencies
-        working-directory: ./benchmark/
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
         run: |
-          export POETRY_VIRTUALENVS_IN_PROJECT=true
-          poetry install -vvv
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
 
-      - name: Lint with flake8
-        working-directory: ./benchmark/
-        run: poetry run flake8
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
 
-      - name: Check black formatting
-        working-directory: ./benchmark/
-        run: poetry run black . --exclude test.py --check
-        if: success() || failure()
+      - name: Install Python dependencies
+        run: poetry install
 
-      - name: Check isort formatting
-        working-directory: ./benchmark/
-        run: poetry run isort . --check
-        if: success() || failure()
-
-      - name: Check for unused imports and pass statements
-        working-directory: ./benchmark/
+      - name: Run pytest with coverage
         run: |
-          cmd="poetry run autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring agbenchmark"
-          $cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
-        if: success() || failure()
+          poetry run pytest -vv \
+            --cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
+            --durations=10 \
+            tests
+        env:
+          CI: true
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
 
-  tests-agbenchmark:
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          flags: agbenchmark,${{ runner.os }}
+
+  self-test-with-agent:
     runs-on: ubuntu-latest
     strategy:
       matrix:
@@ -89,11 +118,11 @@ jobs:
           python-version: ${{ env.min-python-version }}
 
       - name: Install Poetry
-        working-directory: ./${{ matrix.agent-name }}/
        run: |
          curl -sSL https://install.python-poetry.org | python -
 
       - name: Run regression tests
+        working-directory: .
        run: |
          ./run agent start ${{ matrix.agent-name }}
          cd ${{ matrix.agent-name }}
@@ -125,7 +154,6 @@ jobs:
          export BUILD_SKILL_TREE=true
 
          poetry run agbenchmark --mock
-         poetry run pytest -vv -s tests
 
          CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../frontend/assets)') || echo "No diffs"
          if [ ! -z "$CHANGED" ]; then
.github/workflows/forge-ci.yml (vendored, new file): 129 additions

@@ -0,0 +1,129 @@
+name: Forge CI
+
+on:
+  push:
+    branches: [ master, development, ci-test* ]
+    paths:
+      - '.github/workflows/forge-ci.yml'
+      - 'forge/**'
+  pull_request:
+    branches: [ master, development, release-* ]
+    paths:
+      - '.github/workflows/forge-ci.yml'
+      - 'forge/**'
+
+concurrency:
+  group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
+  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
+
+defaults:
+  run:
+    shell: bash
+    working-directory: forge
+
+jobs:
+  test:
+    permissions:
+      contents: read
+    timeout-minutes: 30
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
+
+    steps:
+      # Quite slow on macOS (2~4 minutes to set up Docker)
+      # - name: Set up Docker (macOS)
+      #   if: runner.os == 'macOS'
+      #   uses: crazy-max/ghaction-setup-docker@v3
+
+      - name: Start MinIO service (Linux)
+        if: runner.os == 'Linux'
+        working-directory: '.'
+        run: |
+          docker pull minio/minio:edge-cicd
+          docker run -d -p 9000:9000 minio/minio:edge-cicd
+
+      - name: Start MinIO service (macOS)
+        if: runner.os == 'macOS'
+        working-directory: ${{ runner.temp }}
+        run: |
+          brew install minio/stable/minio
+          mkdir data
+          minio server ./data &
+
+      # No MinIO on Windows:
+      # - Windows doesn't support running Linux Docker containers
+      # - It doesn't seem possible to start background processes on Windows. They are
+      #   killed after the step returns.
+      #   See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
+
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          submodules: true
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
+        uses: actions/cache@v4
+        with:
+          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
+          key: poetry-${{ runner.os }}-${{ hashFiles('forge/poetry.lock') }}
+
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
+        run: |
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
+
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
+
+      - name: Install Python dependencies
+        run: poetry install
+
+      - name: Run pytest with coverage
+        run: |
+          poetry run pytest -vv \
+            --cov=forge --cov-branch --cov-report term-missing --cov-report xml \
+            --durations=10 \
+            forge
+        env:
+          CI: true
+          PLAIN_OUTPUT: True
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+          S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
+          AWS_ACCESS_KEY_ID: minioadmin
+          AWS_SECRET_ACCESS_KEY: minioadmin
+
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          flags: forge,${{ runner.os }}
+
+      - name: Upload logs to artifact
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: test-logs
+          path: forge/logs/
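The workflow above leaves `S3_ENDPOINT_URL` empty on Windows runners because no MinIO service can be started there. One way a test suite can cope with that — a hedged sketch under that assumption, not the actual forge test code — is to gate S3-backed tests on the variable:

```python
import os

import pytest

# Skip S3-backed tests when no S3-compatible endpoint (e.g. MinIO) is up;
# the CI workflow above sets S3_ENDPOINT_URL to '' on Windows runners.
requires_s3 = pytest.mark.skipif(
    not os.getenv("S3_ENDPOINT_URL"),
    reason="no S3-compatible endpoint available on this runner",
)


@requires_s3
def test_s3_roundtrip():
    ...  # would exercise S3FileStorage against the MinIO endpoint
```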
.github/workflows/python-checks.yml (vendored, new file): 151 additions

@@ -0,0 +1,151 @@
+name: Python checks
+
+on:
+  push:
+    branches: [ master, development, ci-test* ]
+    paths:
+      - '.github/workflows/lint-ci.yml'
+      - 'autogpt/**'
+      - 'forge/**'
+      - 'benchmark/**'
+      - '**.py'
+      - '!autogpt/tests/vcr_cassettes'
+  pull_request:
+    branches: [ master, development, release-* ]
+    paths:
+      - '.github/workflows/lint-ci.yml'
+      - 'autogpt/**'
+      - 'forge/**'
+      - 'benchmark/**'
+      - '**.py'
+      - '!autogpt/tests/vcr_cassettes'
+
+concurrency:
+  group: ${{ format('lint-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
+  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
+
+defaults:
+  run:
+    shell: bash
+
+jobs:
+  get-changed-parts:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - id: changes-in
+        name: Determine affected subprojects
+        uses: dorny/paths-filter@v3
+        with:
+          filters: |
+            autogpt:
+              - autogpt/autogpt/**
+              - autogpt/tests/**
+              - autogpt/poetry.lock
+            forge:
+              - forge/forge/**
+              - forge/tests/**
+              - forge/poetry.lock
+            benchmark:
+              - benchmark/agbenchmark/**
+              - benchmark/tests/**
+              - benchmark/poetry.lock
+    outputs:
+      changed-parts: ${{ steps.changes-in.outputs.changes }}
+
+  lint:
+    needs: get-changed-parts
+    runs-on: ubuntu-latest
+    env:
+      min-python-version: "3.10"
+
+    strategy:
+      matrix:
+        sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
+      fail-fast: false
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python ${{ env.min-python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ env.min-python-version }}
+
+      - name: Set up Python dependency cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pypoetry
+          key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
+
+      - name: Install Poetry
+        run: curl -sSL https://install.python-poetry.org | python3 -
+
+      # Install dependencies
+
+      - name: Install Python dependencies
+        run: poetry -C ${{ matrix.sub-package }} install
+
+      # Lint
+
+      - name: Lint (isort)
+        run: poetry run isort --check .
+        working-directory: ${{ matrix.sub-package }}
+
+      - name: Lint (Black)
+        if: success() || failure()
+        run: poetry run black --check .
+        working-directory: ${{ matrix.sub-package }}
+
+      - name: Lint (Flake8)
+        if: success() || failure()
+        run: poetry run flake8 .
+        working-directory: ${{ matrix.sub-package }}
+
+  types:
+    needs: get-changed-parts
+    runs-on: ubuntu-latest
+    env:
+      min-python-version: "3.10"
+
+    strategy:
+      matrix:
+        sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
+      fail-fast: false
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python ${{ env.min-python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ env.min-python-version }}
+
+      - name: Set up Python dependency cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pypoetry
+          key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
+
+      - name: Install Poetry
+        run: curl -sSL https://install.python-poetry.org | python3 -
+
+      # Install dependencies
+
+      - name: Install Python dependencies
+        run: poetry -C ${{ matrix.sub-package }} install
+
+      # Typecheck
+
+      - name: Typecheck
+        if: success() || failure()
+        run: poetry run pyright
+        working-directory: ${{ matrix.sub-package }}
.pre-commit-config.yaml (new file): 127 additions

@@ -0,0 +1,127 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: check-added-large-files
+        args: ["--maxkb=500"]
+      - id: fix-byte-order-marker
+      - id: check-case-conflict
+      - id: check-merge-conflict
+      - id: check-symlinks
+      - id: debug-statements
+
+  - repo: local
+    # isort needs the context of which packages are installed to function, so we
+    # can't use a vendored isort pre-commit hook (which runs in its own isolated venv).
+    hooks:
+      - id: isort-autogpt
+        name: Lint (isort) - AutoGPT
+        entry: poetry -C autogpt run isort
+        files: ^autogpt/
+        types: [file, python]
+        language: system
+
+      - id: isort-forge
+        name: Lint (isort) - Forge
+        entry: poetry -C forge run isort
+        files: ^forge/
+        types: [file, python]
+        language: system
+
+      - id: isort-benchmark
+        name: Lint (isort) - Benchmark
+        entry: poetry -C benchmark run isort
+        files: ^benchmark/
+        types: [file, python]
+        language: system
+
+  - repo: https://github.com/psf/black
+    rev: 23.12.1
+    # Black has sensible defaults, doesn't need package context, and ignores
+    # everything in .gitignore, so it works fine without any config or arguments.
+    hooks:
+      - id: black
+        name: Lint (Black)
+        language_version: python3.10
+
+  - repo: https://github.com/PyCQA/flake8
+    rev: 7.0.0
+    # To have flake8 load the config of the individual subprojects, we have to call
+    # them separately.
+    hooks:
+      - id: flake8
+        name: Lint (Flake8) - AutoGPT
+        alias: flake8-autogpt
+        files: ^autogpt/(autogpt|scripts|tests)/
+        args: [--config=autogpt/.flake8]
+
+      - id: flake8
+        name: Lint (Flake8) - Forge
+        alias: flake8-forge
+        files: ^forge/(forge|tests)/
+        args: [--config=forge/.flake8]
+
+      - id: flake8
+        name: Lint (Flake8) - Benchmark
+        alias: flake8-benchmark
+        files: ^benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
+        args: [--config=benchmark/.flake8]
+
+  - repo: local
+    # To have watertight type checking, we check *all* the files in an affected
+    # project. To trigger on poetry.lock we also reset the file `types` filter.
+    hooks:
+      - id: pyright
+        name: Typecheck - AutoGPT
+        alias: pyright-autogpt
+        entry: poetry -C autogpt run pyright
+        args: [-p, autogpt, autogpt]
+        # include forge source (since it's a path dependency) but exclude *_test.py files:
+        files: ^(autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: pyright
+        name: Typecheck - Forge
+        alias: pyright-forge
+        entry: poetry -C forge run pyright
+        args: [-p, forge, forge]
+        files: ^forge/(forge/|poetry\.lock$)
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: pyright
+        name: Typecheck - Benchmark
+        alias: pyright-benchmark
+        entry: poetry -C benchmark run pyright
+        args: [-p, benchmark, benchmark]
+        files: ^benchmark/(agbenchmark|tests)/
+        types: [file]
+        language: system
+        pass_filenames: false
+
+  - repo: local
+    hooks:
+      - id: pytest-autogpt
+        name: Run tests - AutoGPT (excl. slow tests)
+        entry: bash -c 'cd autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
+        # include forge source (since it's a path dependency) but exclude *_test.py files:
+        files: ^(autogpt/((autogpt|tests)/|poetry\.lock$)|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
+        language: system
+        pass_filenames: false
+
+      - id: pytest-forge
+        name: Run tests - Forge (excl. slow tests)
+        entry: bash -c 'cd forge && poetry run pytest --cov=forge -m "not slow"'
+        files: ^forge/(forge/|tests/|poetry\.lock$)
+        language: system
+        pass_filenames: false
+
+      - id: pytest-benchmark
+        name: Run tests - Benchmark
+        entry: bash -c 'cd benchmark && poetry run pytest --cov=benchmark'
+        files: ^benchmark/(agbenchmark/|tests/|poetry\.lock$)
+        language: system
+        pass_filenames: false
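The `files` patterns above do most of the scoping work; in particular, the negative lookbehind in the AutoGPT pyright hook keeps forge sources in scope (forge is a path dependency) while skipping `*_test.py` files. A quick sanity check of that exact pattern in Python:

```python
import re

# The `files` filter of the pyright-autogpt hook, copied verbatim from above:
pattern = re.compile(
    r"^(autogpt/((autogpt|scripts|tests)/|poetry\.lock$)"
    r"|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)"
)

assert pattern.match("autogpt/autogpt/app/main.py")              # autogpt source
assert pattern.match("forge/forge/llm/providers/schema.py")      # forge source
assert pattern.match("forge/poetry.lock")                        # lockfile trigger
assert not pattern.match("forge/forge/file_storage/s3_test.py")  # test file excluded
```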
@@ -1,11 +1,14 @@
 [flake8]
 max-line-length = 88
-extend-exclude =
-    .*_cache/,
-    .venv,
+# Ignore rules that conflict with Black code style
+extend-ignore = E203, W503
+exclude =
+    .git,
+    __pycache__/,
+    *.pyc,
+    .pytest_cache/,
+    venv*/,
+    .venv/,
     data/,
     logs/,
     tests/unit/data/,
-extend-ignore =
-    # No whitespace before ':' conflicts with Black style for slices
-    E203,
@@ -1,47 +0,0 @@
-repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
-    hooks:
-      - id: check-added-large-files
-        args: ['--maxkb=500']
-      - id: check-byte-order-marker
-      - id: check-case-conflict
-      - id: check-merge-conflict
-      - id: check-symlinks
-      - id: debug-statements
-
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        language_version: python3.10
-
-  - repo: https://github.com/psf/black
-    rev: 23.3.0
-    hooks:
-      - id: black
-        language_version: python3.10
-
-  - repo: https://github.com/PyCQA/flake8
-    rev: 7.0.0
-    hooks:
-      - id: flake8
-
-  # - repo: https://github.com/pre-commit/mirrors-mypy
-  #   rev: 'v1.3.0'
-  #   hooks:
-  #     - id: mypy
-
-  - repo: local
-    hooks:
-      # - id: autoflake
-      #   name: autoflake
-      #   entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests
-      #   language: python
-      #   types: [ python ]
-      - id: pytest-check
-        name: pytest-check
-        entry: bash -c 'cd autogpt && poetry run pytest --cov=autogpt tests/unit'
-        language: system
-        pass_filenames: false
-        always_run: true
@@ -1,77 +0,0 @@
-import asyncio
-import logging
-import sys
-from pathlib import Path
-
-from forge.config.ai_profile import AIProfile
-from forge.config.config import ConfigBuilder
-from forge.file_storage import FileStorageBackendName, get_storage
-from forge.logging.config import configure_logging
-
-from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
-from autogpt.agents.agent_manager import AgentManager
-from autogpt.app.main import _configure_llm_provider, run_interaction_loop
-
-LOG_DIR = Path(__file__).parent / "logs"
-
-
-def run_specific_agent(task: str, continuous_mode: bool = False) -> None:
-    agent = bootstrap_agent(task, continuous_mode)
-    asyncio.run(run_interaction_loop(agent))
-
-
-def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
-    configure_logging(
-        level=logging.DEBUG,
-        log_dir=LOG_DIR,
-        plain_console_output=True,
-    )
-
-    config = ConfigBuilder.build_config_from_env()
-    config.continuous_mode = continuous_mode
-    config.continuous_limit = 20
-    config.noninteractive_mode = True
-
-    ai_profile = AIProfile(
-        ai_name="AutoGPT",
-        ai_role="a multi-purpose AI assistant.",
-        ai_goals=[task],
-    )
-
-    agent_settings = AgentSettings(
-        name=Agent.default_settings.name,
-        agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
-        description=Agent.default_settings.description,
-        ai_profile=ai_profile,
-        config=AgentConfiguration(
-            fast_llm=config.fast_llm,
-            smart_llm=config.smart_llm,
-            allow_fs_access=not config.restrict_to_workspace,
-            use_functions_api=config.openai_functions,
-        ),
-        history=Agent.default_settings.history.copy(deep=True),
-    )
-
-    local = config.file_storage_backend == FileStorageBackendName.LOCAL
-    restrict_to_root = not local or config.restrict_to_workspace
-    file_storage = get_storage(
-        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
-    )
-    file_storage.initialize()
-
-    agent = Agent(
-        settings=agent_settings,
-        llm_provider=_configure_llm_provider(config),
-        file_storage=file_storage,
-        legacy_config=config,
-    )
-    return agent
-
-
-if __name__ == "__main__":
-    # The first argument is the script name itself, second is the task
-    if len(sys.argv) != 2:
-        print("Usage: python script.py <task>")
-        sys.exit(1)
-    task = sys.argv[1]
-    run_specific_agent(task, continuous_mode=True)
@@ -4,7 +4,7 @@ from forge.config.ai_directives import AIDirectives
 from forge.config.ai_profile import AIProfile
 from forge.config.config import Config
 from forge.file_storage.base import FileStorage
-from forge.llm.providers import ChatModelProvider
+from forge.llm.providers import MultiProvider
 
 from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
 
@@ -15,7 +15,7 @@ def create_agent(
     ai_profile: AIProfile,
     app_config: Config,
     file_storage: FileStorage,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
     directives: Optional[AIDirectives] = None,
 ) -> Agent:
     if not task:
@@ -39,7 +39,7 @@ def configure_agent_with_state(
     state: AgentSettings,
     app_config: Config,
     file_storage: FileStorage,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
 ) -> Agent:
     return _configure_agent(
         state=state,
@@ -51,7 +51,7 @@ def configure_agent_with_state(
 
 def _configure_agent(
     app_config: Config,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
     file_storage: FileStorage,
     agent_id: str = "",
     task: str = "",
@@ -59,20 +59,22 @@ def _configure_agent(
     directives: Optional[AIDirectives] = None,
     state: Optional[AgentSettings] = None,
 ) -> Agent:
-    if not (state or agent_id and task and ai_profile and directives):
+    if state:
+        agent_state = state
+    elif agent_id and task and ai_profile and directives:
+        agent_state = state or create_agent_state(
+            agent_id=agent_id,
+            task=task,
+            ai_profile=ai_profile,
+            directives=directives,
+            app_config=app_config,
+        )
+    else:
         raise TypeError(
             "Either (state) or (agent_id, task, ai_profile, directives)"
             " must be specified"
         )
 
-    agent_state = state or create_agent_state(
-        agent_id=agent_id,
-        task=task,
-        ai_profile=ai_profile,
-        directives=directives,
-        app_config=app_config,
-    )
-
     return Agent(
         settings=agent_state,
         llm_provider=llm_provider,
@@ -7,7 +7,7 @@ from forge.file_storage.base import FileStorage
 if TYPE_CHECKING:
     from autogpt.agents.agent import Agent
     from forge.config.config import Config
-    from forge.llm.providers.schema import ChatModelProvider
+    from forge.llm.providers import MultiProvider
 
 from .configurators import _configure_agent
 from .profile_generator import generate_agent_profile_for_task
@@ -18,7 +18,7 @@ async def generate_agent_for_task(
     task: str,
     app_config: Config,
     file_storage: FileStorage,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
 ) -> Agent:
     ai_profile, task_directives = await generate_agent_profile_for_task(
         task=task,
@@ -5,10 +5,10 @@ from forge.config.ai_directives import AIDirectives
 from forge.config.ai_profile import AIProfile
 from forge.config.config import Config
 from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
+from forge.llm.providers import MultiProvider
 from forge.llm.providers.schema import (
     AssistantChatMessage,
     ChatMessage,
-    ChatModelProvider,
     CompletionModelFunction,
 )
 from forge.models.config import SystemConfiguration, UserConfigurable
@@ -141,7 +141,7 @@ class AgentProfileGeneratorConfiguration(SystemConfiguration):
                     required=True,
                 ),
             },
-        ).schema
+        ).dict()
     )
 
 
@@ -160,7 +160,7 @@ class AgentProfileGenerator(PromptStrategy):
         self._model_classification = model_classification
         self._system_prompt_message = system_prompt
         self._user_prompt_template = user_prompt_template
-        self._create_agent_function = CompletionModelFunction.parse(
+        self._create_agent_function = CompletionModelFunction.parse_obj(
             create_agent_function
         )
 
@@ -183,7 +183,7 @@ class AgentProfileGenerator(PromptStrategy):
 
     def parse_response_content(
         self,
-        response_content: AssistantChatMessage,
+        response: AssistantChatMessage,
     ) -> tuple[AIProfile, AIDirectives]:
         """Parse the actual text response from the objective model.
 
@@ -195,15 +195,15 @@ class AgentProfileGenerator(PromptStrategy):
 
         """
         try:
-            if not response_content.tool_calls:
+            if not response.tool_calls:
                 raise ValueError(
                     f"LLM did not call {self._create_agent_function.name} function; "
                     "agent profile creation failed"
                 )
-            arguments: object = response_content.tool_calls[0].function.arguments
+            arguments: object = response.tool_calls[0].function.arguments
             ai_profile = AIProfile(
-                ai_name=arguments.get("name"),
-                ai_role=arguments.get("description"),
+                ai_name=arguments.get("name"),  # type: ignore
+                ai_role=arguments.get("description"),  # type: ignore
             )
             ai_directives = AIDirectives(
                 best_practices=arguments.get("directives", {}).get("best_practices"),
@@ -211,7 +211,7 @@ class AgentProfileGenerator(PromptStrategy):
                 resources=[],
             )
         except KeyError:
-            logger.debug(f"Failed to parse this response content: {response_content}")
+            logger.debug(f"Failed to parse this response content: {response}")
             raise
         return ai_profile, ai_directives
 
@@ -219,7 +219,7 @@ class AgentProfileGenerator(PromptStrategy):
 async def generate_agent_profile_for_task(
     task: str,
     app_config: Config,
-    llm_provider: ChatModelProvider,
+    llm_provider: MultiProvider,
 ) -> tuple[AIProfile, AIDirectives]:
     """Generates an AIConfig object from the given string.
 
@@ -24,7 +24,7 @@ class MyAgent(Agent):
     def __init__(
         self,
         settings: AgentSettings,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
         file_storage: FileStorage,
         legacy_config: Config,
     ):
@@ -3,7 +3,7 @@ from __future__ import annotations
 import inspect
 import logging
 from datetime import datetime
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
 
 import sentry_sdk
 from forge.agent.base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
@@ -14,7 +14,7 @@ from forge.agent.protocols import (
     DirectiveProvider,
     MessageProvider,
 )
-from forge.command.command import Command, CommandOutput
+from forge.command.command import Command
 from forge.components.action_history import (
     ActionHistoryComponent,
     EpisodicActionHistory,
@@ -34,8 +34,8 @@ from forge.llm.prompting.utils import dump_prompt
 from forge.llm.providers import (
     AssistantFunctionCall,
     ChatMessage,
-    ChatModelProvider,
     ChatModelResponse,
+    MultiProvider,
 )
 from forge.llm.providers.utils import function_specs_from_commands
 from forge.models.action import (
@@ -76,7 +76,9 @@ class AgentConfiguration(BaseAgentConfiguration):
 
 
 class AgentSettings(BaseAgentSettings):
-    config: AgentConfiguration = Field(default_factory=AgentConfiguration)
+    config: AgentConfiguration = Field(  # type: ignore
+        default_factory=AgentConfiguration
+    )
 
     history: EpisodicActionHistory[OneShotAgentActionProposal] = Field(
         default_factory=EpisodicActionHistory[OneShotAgentActionProposal]
@@ -86,8 +88,8 @@ class AgentSettings(BaseAgentSettings):
     context: AgentContext = Field(default_factory=AgentContext)
 
 
-class Agent(BaseAgent, Configurable[AgentSettings]):
-    default_settings: AgentSettings = AgentSettings(
+class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
+    default_settings: ClassVar[AgentSettings] = AgentSettings(
         name="Agent",
         description=__doc__ if __doc__ else "",
     )
@@ -95,7 +97,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
     def __init__(
         self,
         settings: AgentSettings,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
         file_storage: FileStorage,
         legacy_config: Config,
     ):
@@ -280,7 +282,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
 
         return result
 
-    async def _execute_tool(self, tool_call: AssistantFunctionCall) -> CommandOutput:
+    async def _execute_tool(self, tool_call: AssistantFunctionCall) -> Any:
         """Execute the command and return the result
 
         Args:
@@ -43,7 +43,7 @@ class AssistantThoughts(ModelWithSummary):
 
 
 class OneShotAgentActionProposal(ActionProposal):
-    thoughts: AssistantThoughts
+    thoughts: AssistantThoughts  # type: ignore
 
 
 class OneShotAgentPromptConfiguration(SystemConfiguration):
@@ -186,11 +186,8 @@ class OneShotAgentPromptStrategy(PromptStrategy):
 
     def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
         response_schema = self.response_schema.copy(deep=True)
-        if (
-            use_functions_api
-            and response_schema.properties
-            and "use_tool" in response_schema.properties
-        ):
+        assert response_schema.properties
+        if use_functions_api and "use_tool" in response_schema.properties:
             del response_schema.properties["use_tool"]
 
         # Unindent for performance
@@ -288,10 +285,10 @@ class OneShotAgentPromptStrategy(PromptStrategy):
             "Parsing object extracted from LLM response:\n"
             f"{json.dumps(assistant_reply_dict, indent=4)}"
         )
 
-        parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
         if self.config.use_functions_api:
             if not response.tool_calls:
                 raise InvalidAgentResponseError("Assistant did not use a tool")
-            parsed_response.use_tool = response.tool_calls[0].function
+            assistant_reply_dict["use_tool"] = response.tool_calls[0].function
 
+        parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
         return parsed_response
@@ -25,7 +25,7 @@ from forge.agent_protocol.models import (
 )
 from forge.config.config import Config
 from forge.file_storage import FileStorage
-from forge.llm.providers import ChatModelProvider, ModelProviderBudget
+from forge.llm.providers import ModelProviderBudget, MultiProvider
 from forge.models.action import ActionErrorResult, ActionSuccessResult
 from forge.utils.const import ASK_COMMAND, FINISH_COMMAND
 from forge.utils.exceptions import AgentFinished, NotFoundError
@@ -49,7 +49,7 @@ class AgentProtocolServer:
         app_config: Config,
         database: AgentDB,
         file_storage: FileStorage,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
     ):
         self.app_config = app_config
         self.db = database
@@ -444,9 +444,7 @@ class AgentProtocolServer:
         agent_id = task_agent_id(task_id)
         return self.file_storage.clone_with_subroot(f"agents/{agent_id}/workspace")
 
-    def _get_task_llm_provider(
-        self, task: Task, step_id: str = ""
-    ) -> ChatModelProvider:
+    def _get_task_llm_provider(self, task: Task, step_id: str = "") -> MultiProvider:
         """
         Configures the LLM provider with headers to link outgoing requests to the task.
         """
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional
 
 from colorama import Fore, Style
 from forge.agent_protocol.database import AgentDB
-from forge.components.code_executor import (
+from forge.components.code_executor.code_executor import (
     is_docker_available,
     we_are_running_in_a_docker_container,
 )
@@ -82,7 +82,9 @@ async def run_auto_gpt(
     local = config.file_storage_backend == FileStorageBackendName.LOCAL
     restrict_to_root = not local or config.restrict_to_workspace
     file_storage = get_storage(
-        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
+        config.file_storage_backend,
+        root_path=Path("data"),
+        restrict_to_root=restrict_to_root,
     )
     file_storage.initialize()
 
@@ -353,7 +355,9 @@ async def run_auto_gpt_server(
     local = config.file_storage_backend == FileStorageBackendName.LOCAL
     restrict_to_root = not local or config.restrict_to_workspace
     file_storage = get_storage(
-        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
+        config.file_storage_backend,
+        root_path=Path("data"),
+        restrict_to_root=restrict_to_root,
    )
     file_storage.initialize()
 
@@ -7,7 +7,7 @@ import re
 import socket
 import sys
 from pathlib import Path
-from typing import Any, Callable, Coroutine, ParamSpec, TypeVar
+from typing import Any, Callable, Coroutine, ParamSpec, TypeVar, cast
 
 import requests
 from colorama import Fore, Style
@@ -88,7 +88,7 @@ def vcs_state_diverges_from_master() -> bool:
 def get_git_user_email() -> str:
     try:
         repo = Repo(search_parent_directories=True)
-        return repo.config_reader().get_value("user", "email", default="")
+        return cast(str, repo.config_reader().get_value("user", "email", default=""))
     except InvalidGitRepositoryError:
         return ""
 
autogpt/poetry.lock (generated): 529 changes. File diff suppressed because one or more lines are too long.
@@ -1,9 +1,7 @@
 [tool.poetry]
 name = "agpt"
 version = "0.5.0"
-authors = [
-    "Significant Gravitas <support@agpt.co>",
-]
+authors = ["Significant Gravitas <support@agpt.co>"]
 readme = "README.md"
 description = "An open-source attempt to make GPT-4 autonomous"
 homepage = "https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt"
@@ -30,11 +28,10 @@ charset-normalizer = "^3.1.0"
 click = "*"
 colorama = "^0.4.6"
 distro = "^1.8.0"
-en-core-web-sm = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl"}
+en-core-web-sm = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl" }
 fastapi = "^0.109.1"
 ftfy = "^6.1.1"
 google-api-python-client = "*"
-gTTS = "^2.3.1"
 hypercorn = "^0.14.4"
 inflection = "*"
 jsonschema = "*"
@@ -58,21 +55,18 @@ openapi-python-client = "^0.14.0"
 # Benchmarking
 agbenchmark = { path = "../benchmark", optional = true }
 # agbenchmark = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "benchmark", optional = true}
-google-cloud-logging = "^3.8.0"
-google-cloud-storage = "^2.13.0"
 psycopg2-binary = "^2.9.9"
 
 [tool.poetry.extras]
 benchmark = ["agbenchmark"]
 
 [tool.poetry.group.dev.dependencies]
-black = "*"
-boto3-stubs = {extras = ["s3"], version = "^1.33.6"}
-flake8 = "*"
+black = "^23.12.1"
+flake8 = "^7.0.0"
 gitpython = "^3.1.32"
-isort = "*"
-mypy = "*"
+isort = "^5.13.1"
 pre-commit = "*"
+pyright = "^1.1.364"
 types-beautifulsoup4 = "*"
 types-colorama = "*"
 types-Markdown = "*"
@@ -89,7 +83,7 @@ pytest-integration = "*"
 pytest-mock = "*"
 pytest-recording = "*"
 pytest-xdist = "*"
-vcrpy = {git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master"}
+vcrpy = { git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master" }
 
 
 [build-system]
@@ -101,50 +95,18 @@ build-backend = "poetry.core.masonry.api"
 line-length = 88
 target-version = ['py310']
 include = '\.pyi?$'
-packages = ["autogpt"]
-extend-exclude = '.+/(dist|.venv|venv|build|data)/.+'
 
 
 [tool.isort]
 profile = "black"
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
-ensure_newline_before_comments = true
-line_length = 88
-sections = [
-    "FUTURE",
-    "STDLIB",
-    "THIRDPARTY",
-    "FIRSTPARTY",
-    "LOCALFOLDER"
-]
-extend_skip = [
-    "agbenchmark_config/temp_folder/",
-    "data/",
-]
+skip_glob = ["data"]
 
 
-[tool.mypy]
-follow_imports = 'skip'
-check_untyped_defs = true
-disallow_untyped_calls = true
-files = [
-    'autogpt/**/*.py',
-    'tests/**/*.py'
-]
-
-[[tool.mypy.overrides]]
-module = [
-    'requests.*',
-    'yaml.*'
-]
-ignore_missing_imports = true
+[tool.pyright]
+pythonVersion = "3.10"
+exclude = ["data/**", "**/node_modules", "**/__pycache__", "**/.*"]
+ignore = ["../forge/**"]
 
 
 [tool.pytest.ini_options]
-markers = [
-    "requires_openai_api_key",
-    "requires_huggingface_api_key"
-]
+markers = ["slow", "requires_openai_api_key", "requires_huggingface_api_key"]
@@ -4,12 +4,12 @@ import sys
 from importlib.metadata import version
 
 try:
-    import poetry.factory  # noqa
+    import poetry.factory  # type: ignore # noqa
 except ModuleNotFoundError:
     os.system(f"{sys.executable} -m pip install 'poetry>=1.6.1,<2.0.0'")
 
-from poetry.core.constraints.version.version import Version
-from poetry.factory import Factory
+from poetry.core.constraints.version.version import Version  # type: ignore
+from poetry.factory import Factory  # type: ignore
 
 
 def main():
@@ -20,7 +20,7 @@ from autogpt.app.utils import coroutine
 )
 @coroutine
 async def generate_release_notes(repo_path: Optional[Path] = None):
-    logger = logging.getLogger(generate_release_notes.name)
+    logger = logging.getLogger(generate_release_notes.name)  # pyright: ignore
 
     repo = Repo(repo_path, search_parent_directories=True)
     tags = list(repo.tags)
@@ -12,7 +12,7 @@ from forge.file_storage.local import (
     FileStorageConfiguration,
     LocalFileStorage,
 )
-from forge.llm.providers import ChatModelProvider
+from forge.llm.providers import MultiProvider
 from forge.logging.config import configure_logging
 
 from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@@ -71,14 +71,12 @@ def setup_logger(config: Config):
 
 
 @pytest.fixture
-def llm_provider(config: Config) -> ChatModelProvider:
+def llm_provider(config: Config) -> MultiProvider:
     return _configure_llm_provider(config)
 
 
 @pytest.fixture
-def agent(
-    config: Config, llm_provider: ChatModelProvider, storage: FileStorage
-) -> Agent:
+def agent(config: Config, llm_provider: MultiProvider, storage: FileStorage) -> Agent:
     ai_profile = AIProfile(
         ai_name="Base",
         ai_role="A base AI",
@@ -1,13 +1,16 @@
+from pathlib import Path
+
 import pytest
 from forge.config.ai_profile import AIProfile
 from forge.config.config import Config
 from forge.file_storage import FileStorageBackendName, get_storage
+from forge.llm.providers import MultiProvider
 
 from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
 
 
 @pytest.fixture
-def dummy_agent(config: Config, llm_provider, memory_json_file):
+def dummy_agent(config: Config, llm_provider: MultiProvider):
     ai_profile = AIProfile(
         ai_name="Dummy Agent",
         ai_role="Dummy Role",
@@ -31,7 +34,9 @@ def dummy_agent(config: Config, llm_provider, memory_json_file):
     local = config.file_storage_backend == FileStorageBackendName.LOCAL
     restrict_to_root = not local or config.restrict_to_workspace
     file_storage = get_storage(
-        config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
+        config.file_storage_backend,
+        root_path=Path("data"),
+        restrict_to_root=restrict_to_root,
     )
     file_storage.initialize()
 
@@ -4,7 +4,7 @@ import tempfile
 from pathlib import Path
 
 import pytest
-from forge.components.code_executor import (
+from forge.components.code_executor.code_executor import (
     ALLOWLIST_CONTROL,
     CodeExecutorComponent,
     is_docker_available,
@@ -257,17 +257,3 @@ def test_huggingface_fail_request_bad_image(
     result = image_gen_component.generate_image("astronaut riding a horse", 512)
 
     assert result == "Error creating image."
-
-
-def test_huggingface_fail_missing_api_token(
-    mocker, image_gen_component: ImageGeneratorComponent, agent: Agent
-):
-    agent.legacy_config.image_provider = "huggingface"
-    agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
-
-    # Mock requests.post to raise ValueError
-    mocker.patch("requests.post", side_effect=ValueError)
-
-    # Verify request raises an error.
-    with pytest.raises(ValueError):
-        image_gen_component.generate_image("astronaut riding a horse", 512)
@@ -67,8 +67,8 @@ def test_missing_azure_config(config: Config) -> None:
     with pytest.raises(ValueError):
         config.openai_credentials.load_azure_config(config_file)
 
-    assert config.openai_credentials.api_type != "azure"
-    assert config.openai_credentials.api_version == ""
+    assert config.openai_credentials.api_type != SecretStr("azure")
+    assert config.openai_credentials.api_version is None
     assert config.openai_credentials.azure_model_to_deploy_id_map is None
 
 
@@ -98,8 +98,8 @@ azure_model_map:
 
 def test_azure_config(config_with_azure: Config) -> None:
     assert (credentials := config_with_azure.openai_credentials) is not None
-    assert credentials.api_type == "azure"
-    assert credentials.api_version == "2023-06-01-preview"
+    assert credentials.api_type == SecretStr("azure")
+    assert credentials.api_version == SecretStr("2023-06-01-preview")
     assert credentials.azure_endpoint == SecretStr("https://dummy.openai.azure.com")
     assert credentials.azure_model_to_deploy_id_map == {
         config_with_azure.fast_llm: "FAST-LLM_ID",
@@ -4,7 +4,7 @@ from pathlib import Path
 
 import pytest
 import pytest_asyncio
-from forge.file_storage import GCSFileStorage, GCSFileStorageConfiguration
+from forge.file_storage.gcs import GCSFileStorage, GCSFileStorageConfiguration
 from google.auth.exceptions import GoogleAuthError
 from google.cloud import storage
 from google.cloud.exceptions import NotFound
@@ -14,6 +14,8 @@ try:
 except GoogleAuthError:
     pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)
 
+pytestmark = pytest.mark.slow
+
 
 @pytest.fixture(scope="module")
 def gcs_bucket_name() -> str:
@@ -26,7 +28,7 @@ def gcs_root() -> Path:
 
 
 @pytest.fixture(scope="module")
-def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path) -> GCSFileStorage:
+def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path):
     os.environ["STORAGE_BUCKET"] = gcs_bucket_name
     storage_config = GCSFileStorageConfiguration.from_env()
     storage_config.root = gcs_root
@@ -52,7 +54,7 @@ def test_initialize(gcs_bucket_name: str, gcs_storage_uninitialized: GCSFileStor
 
 
 @pytest.fixture(scope="module")
-def gcs_storage(gcs_storage_uninitialized: GCSFileStorage) -> GCSFileStorage:
+def gcs_storage(gcs_storage_uninitialized: GCSFileStorage):
     (gcs_storage := gcs_storage_uninitialized).initialize()
     yield gcs_storage  # type: ignore
 
@@ -77,7 +79,7 @@ TEST_FILES: list[tuple[str | Path, str]] = [
 
 
 @pytest_asyncio.fixture
-async def gcs_storage_with_files(gcs_storage: GCSFileStorage) -> GCSFileStorage:
+async def gcs_storage_with_files(gcs_storage: GCSFileStorage):
     for file_name, file_content in TEST_FILES:
         gcs_storage._bucket.blob(
             str(gcs_storage.get_path(file_name))
@@ -1,7 +1,7 @@
 import json
 
 import pytest
-from forge.json import json_loads
+from forge.json.parsing import json_loads
 
 _JSON_FIXABLE: list[tuple[str, str]] = [
     # Missing comma
@@ -1,7 +1,7 @@
 from pathlib import Path
 
 import pytest
-from forge.file_storage import FileStorageConfiguration, LocalFileStorage
+from forge.file_storage.local import FileStorageConfiguration, LocalFileStorage
 
 _ACCESSIBLE_PATHS = [
     Path("."),
@@ -5,7 +5,7 @@ from pathlib import Path
 import pytest
 import pytest_asyncio
 from botocore.exceptions import ClientError
-from forge.file_storage import S3FileStorage, S3FileStorageConfiguration
+from forge.file_storage.s3 import S3FileStorage, S3FileStorageConfiguration
 
 if not (os.getenv("S3_ENDPOINT_URL") and os.getenv("AWS_ACCESS_KEY_ID")):
     pytest.skip("S3 environment variables are not set", allow_module_level=True)
@@ -22,7 +22,7 @@ def s3_root() -> Path:
 
 
 @pytest.fixture
-def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path) -> S3FileStorage:
+def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path):
     os.environ["STORAGE_BUCKET"] = s3_bucket_name
     storage_config = S3FileStorageConfiguration.from_env()
     storage_config.root = s3_root
@@ -36,12 +36,13 @@ def test_initialize(s3_bucket_name: str, s3_storage_uninitialized: S3FileStorage
 
     # test that the bucket doesn't exist yet
     with pytest.raises(ClientError):
-        s3.meta.client.head_bucket(Bucket=s3_bucket_name)
+        s3.meta.client.head_bucket(Bucket=s3_bucket_name)  # pyright: ignore
 
     s3_storage_uninitialized.initialize()
 
     # test that the bucket has been created
-    s3.meta.client.head_bucket(Bucket=s3_bucket_name)
+    s3.meta.client.head_bucket(Bucket=s3_bucket_name)  # pyright: ignore
+    # FIXME: remove the "pyright: ignore" comments after moving this test file to forge
 
 
 def test_workspace_bucket_name(
@@ -52,7 +53,7 @@ def test_workspace_bucket_name(
 
 
 @pytest.fixture
-def s3_storage(s3_storage_uninitialized: S3FileStorage) -> S3FileStorage:
+def s3_storage(s3_storage_uninitialized: S3FileStorage):
     (s3_storage := s3_storage_uninitialized).initialize()
     yield s3_storage  # type: ignore
 
@@ -71,7 +72,7 @@ TEST_FILES: list[tuple[str | Path, str]] = [
 
 
 @pytest_asyncio.fixture
-async def s3_storage_with_files(s3_storage: S3FileStorage) -> S3FileStorage:
+async def s3_storage_with_files(s3_storage: S3FileStorage):
     for file_name, file_content in TEST_FILES:
         s3_storage._bucket.Object(str(s3_storage.get_path(file_name))).put(
             Body=file_content
@@ -1,6 +1,7 @@
 import logging
 import os
 from hashlib import sha256
+from typing import cast
 
 import pytest
 from openai import OpenAI
@@ -53,11 +54,14 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
     def _patched_prepare_options(self, options: FinalRequestOptions):
         _prepare_options(options)
 
+        if not options.json_data:
+            return
+
         headers: dict[str, str | Omit] = (
             {**options.headers} if is_given(options.headers) else {}
         )
         options.headers = headers
-        data: dict = options.json_data
+        data = cast(dict, options.json_data)
 
         logging.getLogger("cached_openai_client").debug(
             f"Outgoing API request: {headers}\n{data if data else None}"
@@ -1,15 +1,12 @@
 [flake8]
 max-line-length = 88
-select = "E303, W293, W291, W292, E305, E231, E302"
+# Ignore rules that conflict with Black code style
+extend-ignore = E203, W503
 exclude =
-    .tox,
-    __pycache__,
+    __pycache__/,
     *.pyc,
-    .env
-    venv*/*,
-    .venv/*,
-    reports/*,
-    dist/*,
-    agent/*,
-    code,
-    agbenchmark/challenges/*
+    .pytest_cache/,
+    venv*/,
+    .venv/,
+    reports/,
+    agbenchmark/reports/,
@@ -1,36 +0,0 @@
-repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
-    hooks:
-      - id: check-added-large-files
-        args: ['--maxkb=500']
-      - id: check-byte-order-marker
-      - id: check-case-conflict
-      - id: check-merge-conflict
-      - id: check-symlinks
-      - id: debug-statements
-
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        language_version: python3.10
-
-  - repo: https://github.com/psf/black
-    rev: 23.3.0
-    hooks:
-      - id: black
-        language_version: python3.10
-
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v1.3.0'
-    hooks:
-      - id: mypy
-
-  - repo: local
-    hooks:
-      - id: autoflake
-        name: autoflake
-        entry: autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring --in-place agbenchmark
-        language: python
-        types: [ python ]
@@ -28,7 +28,7 @@ async def run_api_agent(
     configuration = Configuration(host=config.host)
     async with ApiClient(configuration) as api_client:
         api_instance = AgentApi(api_client)
-        task_request_body = TaskRequestBody(input=task)
+        task_request_body = TaskRequestBody(input=task, additional_input=None)
 
         start_time = time.time()
         response = await api_instance.create_agent_task(
@@ -106,8 +106,8 @@ def find_agbenchmark_without_uvicorn():
 
 
 class CreateReportRequest(BaseModel):
-    test: str = None
-    test_run_id: str = None
+    test: str
+    test_run_id: str
     # category: Optional[str] = []
     mock: Optional[bool] = False
 
@@ -178,8 +178,8 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
         logger.debug(f"Benchmark finished running in {time.time() - start_time} s")
 
         # List all folders in the current working directory
-        path_reports = agbenchmark_config.reports_folder
-        folders = [folder for folder in path_reports.iterdir() if folder.is_dir()]
+        reports_folder = agbenchmark_config.reports_folder
+        folders = [folder for folder in reports_folder.iterdir() if folder.is_dir()]
 
         # Sort the folders based on their names
         sorted_folders = sorted(folders, key=lambda x: x.name)
@@ -196,13 +196,14 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
                 data = json.load(file)
                 logger.debug(f"Report data: {data}")
             else:
-                logger.error(
+                raise HTTPException(
+                    502,
                     "Could not get result after running benchmark: "
-                    f"'report.json' does not exist in '{latest_folder}'"
+                    f"'report.json' does not exist in '{latest_folder}'",
                 )
         else:
-            logger.error(
-                "Could not get result after running benchmark: no reports found"
+            raise HTTPException(
+                504, "Could not get result after running benchmark: no reports found"
             )
 
         return data
@@ -239,7 +240,9 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
             api_instance = AgentApi(api_client)
             task_input = challenge_info.task
 
-            task_request_body = TaskRequestBody(input=task_input)
+            task_request_body = TaskRequestBody(
+                input=task_input, additional_input=None
+            )
             task_response = await api_instance.create_agent_task(
                 task_request_body=task_request_body
             )
@@ -276,7 +279,7 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
             # Forward the request
             response = await client.post(
                 new_url,
-                data=await request.body(),
+                content=await request.body(),
                 headers=dict(request.headers),
             )
 
@@ -1,7 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import AsyncIterator, ClassVar, Optional
+from typing import AsyncIterator, Awaitable, ClassVar, Optional
 
 import pytest
 from agent_protocol_client import AgentApi, Step
@@ -54,7 +54,7 @@ class BaseChallenge(ABC):
         config: AgentBenchmarkConfig,
         request: pytest.FixtureRequest,
         i_attempt: int,
-    ) -> None:
+    ) -> None | Awaitable[None]:
         """
         Test method for use by Pytest-based benchmark sessions. Should return normally
         if the challenge passes, and raise a (preferably descriptive) error otherwise.
@@ -1,4 +1,3 @@
-from collections import deque
 import glob
 import json
 import logging
@@ -6,19 +5,17 @@ import os
 import subprocess
 import sys
 import tempfile
+from collections import deque
 from pathlib import Path
-from typing import Any, ClassVar, Iterator, Literal, Optional
+from typing import Annotated, Any, ClassVar, Iterator, Literal, Optional
 
 import pytest
-from agent_protocol_client import (
-    AgentApi,
-    ApiClient,
-    Configuration as ClientConfig,
-    Step,
-)
+from agent_protocol_client import AgentApi, ApiClient
+from agent_protocol_client import Configuration as ClientConfig
+from agent_protocol_client import Step
 from colorama import Fore, Style
 from openai import _load_client as get_openai_client
-from pydantic import BaseModel, constr, Field, validator
+from pydantic import BaseModel, Field, constr, validator
 
 from agbenchmark.agent_api_interface import download_agent_artifacts_into_folder
 from agbenchmark.agent_interface import copy_challenge_artifacts_into_workspace
@@ -49,7 +46,7 @@ class BuiltinChallengeSpec(BaseModel):
 
     class Info(BaseModel):
         difficulty: DifficultyLevel
-        description: constr(regex=r"^Tests if the agent can.*")
+        description: Annotated[str, constr(regex=r"^Tests if the agent can.*")]
         side_effects: list[str] = Field(default_factory=list)
 
     info: Info
@@ -184,7 +181,7 @@ class BuiltinChallenge(BaseChallenge):
         steps: list[Step] = []
         try:
             async for step in self.run_challenge(
-                config, timeout, mock=request.config.getoption("--mock")
+                config, timeout, mock=bool(request.config.getoption("--mock"))
             ):
                 if not task_id:
                     task_id = step.task_id
@@ -199,6 +196,8 @@ class BuiltinChallenge(BaseChallenge):
             timed_out = False
         except TimeoutError:
             timed_out = True
+
+        assert isinstance(request.node, pytest.Item)
         request.node.user_properties.append(("steps", steps))
         request.node.user_properties.append(("n_steps", n_steps))
         request.node.user_properties.append(("timed_out", timed_out))
@@ -411,15 +410,10 @@ class BuiltinChallenge(BaseChallenge):
 def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
     logger.info("Loading built-in challenges...")
 
-    challenges_path = os.path.dirname(__file__)
+    challenges_path = Path(__file__).parent
     logger.debug(f"Looking for challenge spec files in {challenges_path}...")
 
-    json_files = deque(
-        glob.glob(
-            f"{challenges_path}/**/data.json",
-            recursive=True,
-        )
-    )
+    json_files = deque(challenges_path.rglob("data.json"))
 
     logger.debug(f"Found {len(json_files)} built-in challenges.")
 
@@ -431,7 +425,7 @@ def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
             ignored += 1
             continue
 
-        challenge = BuiltinChallenge.from_challenge_spec_file(Path(json_file))
+        challenge = BuiltinChallenge.from_challenge_spec_file(json_file)
         logger.debug(f"Generated test for {challenge.info.name}")
         yield challenge
 
@@ -442,8 +436,8 @@ def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
     )
 
 
-def _challenge_should_be_ignored(json_file_path: str):
+def _challenge_should_be_ignored(json_file_path: Path):
     return (
-        "challenges/deprecated" in json_file_path
-        or "challenges/library" in json_file_path
+        "challenges/deprecated" in json_file_path.as_posix()
+        or "challenges/library" in json_file_path.as_posix()
     )
@@ -23,9 +23,10 @@ def test_get_ethereum_price() -> None:
     real_eth_price_value = float(real_eth_price)
 
     # Check if the eth price is within $50 of the actual Ethereum price
-    assert (
-        abs(real_eth_price_value - eth_price_value) <= 50
-    ), f"AssertionError: Ethereum price is not within $50 of the actual Ethereum price (Provided price: ${eth_price}, Real price: ${real_eth_price})"
+    assert abs(real_eth_price_value - eth_price_value) <= 50, (
+        "AssertionError: Ethereum price is not within $50 of the actual Ethereum price "
+        f"(Provided price: ${eth_price}, Real price: ${real_eth_price})"
+    )
 
     print("Matches")
 
@@ -23,9 +23,10 @@ def test_get_ethereum_price() -> None:
     real_eth_price_value = float(real_eth_price)
 
     # Check if the eth price is within $50 of the actual Ethereum price
-    assert (
-        abs(real_eth_price_value - eth_price_value) <= 50
-    ), f"AssertionError: Ethereum price is not within $50 of the actual Ethereum price (Provided price: ${eth_price}, Real price: ${real_eth_price})"
+    assert abs(real_eth_price_value - eth_price_value) <= 50, (
+        "AssertionError: Ethereum price is not within $50 of the actual Ethereum price "
+        f"(Provided price: ${eth_price}, Real price: ${real_eth_price})"
+    )
 
     print("Matches")
 
@@ -1,4 +1,3 @@
-# mypy: ignore-errors
 from typing import List, Optional
 
 
@@ -1,4 +1,4 @@
-# mypy: ignore-errors
+# pyright: reportMissingImports=false
 from typing import List
 
 from sample_code import three_sum
@@ -21,7 +21,6 @@ def generate_password(length: int = 8) -> str:
 
 if __name__ == "__main__":
     password_length = (
-        int(sys.argv[sys.argv.index("--length") + 1])
-        if "--length" in sys.argv else 8
+        int(sys.argv[sys.argv.index("--length") + 1]) if "--length" in sys.argv else 8
     )
     print(generate_password(password_length))
@@ -1,3 +1,4 @@
+# pyright: reportMissingImports=false
 import unittest
 
 import password_generator
@@ -18,7 +19,9 @@ class TestPasswordGenerator(unittest.TestCase):
     def test_password_content(self):
         password = password_generator.generate_password()
         self.assertTrue(any(c.isdigit() for c in password))
-        self.assertTrue(any(c in password_generator.string.punctuation for c in password))
+        self.assertTrue(
+            any(c in password_generator.string.punctuation for c in password)
+        )
 
 
 if __name__ == "__main__":
@@ -1,3 +1,4 @@
+# pyright: reportMissingImports=false
 import unittest
 
 from url_shortener import retrieve_url, shorten_url
@@ -56,7 +56,7 @@ def winner(board):
 
 def getLocation():
     location = input(
-        "Choose where to play. Enter two numbers separated by a comma, for example: 1,1 "
+        "Choose where to play. Enter two numbers separated by a comma [example: 1,1]: "
     )
     print(f"\nYou picked {location}")
     coordinates = [int(x) for x in location.split(",")]
@@ -69,7 +69,8 @@ def getLocation():
     ):
         print("You inputted a location in an invalid format")
         location = input(
-            "Choose where to play. Enter two numbers separated by a comma, for example: 1,1 "
+            "Choose where to play. Enter two numbers separated by a comma "
+            "[example: 1,1]: "
        )
         coordinates = [int(x) for x in location.split(",")]
     return coordinates
@@ -37,15 +37,14 @@ class GameStatus(BaseModel):
     winner: Optional[str]
 
 
-from typing import List
-
-
 class Game(BaseModel):
     game_id: str
-    players: List[str]
-    board: dict  # This could represent the state of the game board, you might need to flesh this out further
-    ships: List[ShipPlacement]  # List of ship placements for this game
-    turns: List[Turn]  # List of turns that have been taken
+    players: list[str]
+    # This could represent the state of the game board,
+    # you might need to flesh this out further:
+    board: dict
+    ships: list[ShipPlacement]  # List of ship placements for this game
+    turns: list[Turn]  # List of turns that have been taken
 
 
 class AbstractBattleship(ABC):
@@ -86,7 +85,7 @@ class AbstractBattleship(ABC):
         pass
 
     @abstractmethod
-    def get_game(self) -> Game:
+    def get_game(self) -> Game | None:
         """
         Retrieve the state of the game.
         """
@@ -103,5 +102,8 @@ class AbstractBattleship(ABC):
     def create_game(self) -> None:
         """
         Create a new game.
+
+        Returns:
+            str: The ID of the created game.
         """
         pass
@@ -1,3 +1,4 @@
+# pyright: reportMissingImports=false
 import pytest
 from abstract_class import ShipPlacement, Turn
 from battleship import Battleship
@@ -50,7 +50,7 @@ def test_cant_hit_before_ships_placed(battleship_game):
 
 
 def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id):
-    game = battleship_game.get_game(initialized_game_id)
+    battleship_game.get_game(initialized_game_id)
     additional_ship = ShipPlacement(
         ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal"
     )
@@ -61,6 +61,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
         {"row": 1, "column": "H"},
     ]
 
+    response = None
     for index, hit in enumerate(hits):
         turn = Turn(target={"row": 2, "column": hit})
         response = battleship_game.create_turn(initialized_game_id, turn)
@@ -69,7 +70,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
         static_turn = Turn(target=static_moves[index])
         battleship_game.create_turn(initialized_game_id, static_turn)
 
-    assert response.result == "sunk"
+    assert response and response.result == "sunk"
 
 
 def test_restart_game(battleship_game):
@@ -37,15 +37,14 @@ class GameStatus(BaseModel):
     winner: Optional[str]
 
 
-from typing import List
-
-
 class Game(BaseModel):
     game_id: str
-    players: List[str]
-    board: dict  # This could represent the state of the game board, you might need to flesh this out further
-    ships: List[ShipPlacement]  # List of ship placements for this game
-    turns: List[Turn]  # List of turns that have been taken
+    players: list[str]
+    # This could represent the state of the game board,
+    # you might need to flesh this out further:
+    board: dict
+    ships: list[ShipPlacement]  # List of ship placements for this game
+    turns: list[Turn]  # List of turns that have been taken
 
 
 class AbstractBattleship(ABC):
@@ -86,7 +85,7 @@ class AbstractBattleship(ABC):
         pass
 
     @abstractmethod
-    def get_game(self) -> Game:
+    def get_game(self, game_id: str) -> Game | None:
         """
         Retrieve the state of the game.
         """
@@ -100,8 +99,11 @@ class AbstractBattleship(ABC):
         pass
 
     @abstractmethod
-    def create_game(self) -> None:
+    def create_game(self) -> str:
         """
         Create a new game.
+
+        Returns:
+            str: The ID of the created game.
         """
         pass
@@ -1,14 +1,20 @@
 from typing import Dict
 
-from abstract_class import (AbstractBattleship, Game, GameStatus,
-                            ShipPlacement, Turn, TurnResponse)
+from abstract_class import (
+    AbstractBattleship,
+    Game,
+    GameStatus,
+    ShipPlacement,
+    Turn,
+    TurnResponse,
+)
 
 
 class Battleship(AbstractBattleship):
     def __init__(self):
-        self.games: Dict[int, Game] = {}
+        self.games: Dict[str, Game] = {}
 
-    def create_game(self) -> int:
+    def create_game(self) -> str:
         game_id = str(len(self.games))
         new_game = Game(
             game_id=game_id,
@@ -19,7 +25,7 @@ class Battleship(AbstractBattleship):
         )
 
         self.games[game_id] = new_game
-        return new_game.game_id
+        return game_id
 
     def create_ship_placement(self, game_id: str, placement: ShipPlacement) -> None:
         game = self.games.get(game_id)
@@ -79,38 +85,34 @@ class Battleship(AbstractBattleship):
 
         game.turns.append(turn)
 
-        if hit_ship == "hit":
+        if not hit_ship or hit_ship == "hit":  # if no ship or already hit
             return TurnResponse(result="miss", ship_type=None)
 
-        if hit_ship:
-            ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
+        ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
+        start_row, start_col = (
+            ship_placement.start["row"],
+            ord(ship_placement.start["column"]) - ord("A"),
+        )
+        ship_positions = [
+            (
+                start_row + (i if ship_placement.direction == "vertical" else 0),
+                start_col + (i if ship_placement.direction == "horizontal" else 0),
+            )
+            for i in range(self.SHIP_LENGTHS[hit_ship])
+        ]
 
-        if hit_ship:
-            ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
-            start_row, start_col = ship_placement.start["row"], ord(
-                ship_placement.start["column"]
-            ) - ord("A")
-            ship_positions = [
-                (
-                    start_row + (i if ship_placement.direction == "vertical" else 0),
-                    start_col + (i if ship_placement.direction == "horizontal" else 0),
-                )
-                for i in range(self.SHIP_LENGTHS[hit_ship])
-            ]
+        targeted_positions = {
+            (t.target["row"], ord(t.target["column"]) - ord("A")) for t in game.turns
+        }
 
-            targeted_positions = {
-                (t.target["row"], ord(t.target["column"]) - ord("A"))
-                for t in game.turns
-            }
+        game.board[(target_row, target_col)] = "hit"
 
-            game.board[(target_row, target_col)] = "hit"
-
-            if set(ship_positions).issubset(targeted_positions):
-                for pos in ship_positions:
-                    game.board[pos] = "hit"
-                return TurnResponse(result="sunk", ship_type=hit_ship)
-            else:
-                return TurnResponse(result="hit", ship_type=hit_ship)
+        if set(ship_positions).issubset(targeted_positions):
+            for pos in ship_positions:
+                game.board[pos] = "hit"
+            return TurnResponse(result="sunk", ship_type=hit_ship)
+        else:
+            return TurnResponse(result="hit", ship_type=hit_ship)
 
     def get_game_status(self, game_id: str) -> GameStatus:
         game = self.games.get(game_id)
@@ -132,12 +134,12 @@ class Battleship(AbstractBattleship):
     def get_winner(self, game_id: str) -> str:
         game_status = self.get_game_status(game_id)
 
-        if game_status.is_game_over:
+        if game_status.is_game_over and game_status.winner:
             return game_status.winner
         else:
-            return None
+            raise ValueError(f"Game {game_id} isn't over yet")
 
-    def get_game(self, game_id: str) -> Game:
+    def get_game(self, game_id: str) -> Game | None:
         return self.games.get(game_id)
 
     def delete_game(self, game_id: str) -> None:
@@ -50,7 +50,7 @@ def test_cant_hit_before_ships_placed(battleship_game):
 
 
 def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id):
-    game = battleship_game.get_game(initialized_game_id)
+    battleship_game.get_game(initialized_game_id)
     additional_ship = ShipPlacement(
         ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal"
     )
@@ -61,6 +61,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
         {"row": 1, "column": "H"},
     ]
 
+    response = None
     for index, hit in enumerate(hits):
         turn = Turn(target={"row": 2, "column": hit})
         response = battleship_game.create_turn(initialized_game_id, turn)
@@ -69,7 +70,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
         static_turn = Turn(target=static_moves[index])
         battleship_game.create_turn(initialized_game_id, static_turn)
 
-    assert response.result == "sunk"
+    assert response and response.result == "sunk"
 
 
 def test_restart_game(battleship_game):
@@ -6,7 +6,7 @@ from typing import ClassVar, Iterator, Literal
 import pytest
 import requests
 from agent_protocol_client import AgentApi, Step
-from pydantic import BaseModel, validator, ValidationError
+from pydantic import BaseModel, ValidationError, validator
 
 from agbenchmark.config import AgentBenchmarkConfig
 from agbenchmark.utils.data_types import Category, EvalResult
@@ -93,11 +93,12 @@ class Eval(ABC):
         ...
 
 
-class StringEval(BaseModel, Eval):
-    type: ReferenceAnswerType
+class BaseStringEval(BaseModel, Eval):
+    # type: ReferenceAnswerType
+    pass
 
 
-class ExactStringMatchEval(StringEval):
+class ExactStringMatchEval(BaseStringEval):
     type: Literal["exact_match"] = "exact_match"
     reference_answer: str
 
@@ -109,7 +110,7 @@ class ExactStringMatchEval(StringEval):
         return string == self.reference_answer
 
 
-class FuzzyStringMatchEval(StringEval):
+class FuzzyStringMatchEval(BaseStringEval):
     type: Literal["fuzzy_match"] = "fuzzy_match"
     reference_answer: str
 
@@ -122,7 +123,7 @@ class FuzzyStringMatchEval(StringEval):
         return self.reference_answer.lower() in string.lower()
 
 
-class MustIncludeStringEval(StringEval):
+class MustIncludeStringEval(BaseStringEval):
     type: Literal["must_include"] = "must_include"
     reference_answer: str
 
@@ -134,6 +135,9 @@ class MustIncludeStringEval(StringEval):
         return self.reference_answer.lower() in string.lower()
 
 
+StringEval = ExactStringMatchEval | FuzzyStringMatchEval | MustIncludeStringEval
+
+
 class UrlMatchEval(BaseModel, Eval):
     url: str
     """Example: `"__WIKI__/wiki/Octopus"`"""
@@ -142,8 +146,8 @@ class UrlMatchEval(BaseModel, Eval):
     def description(self) -> str:
         return f"Agent must navigate to '{self.url}'"
 
-    def evaluate(self, url: str) -> bool:
-        return url == resolve_uri(self.url)
+    def evaluate(self, string: str) -> bool:
+        return string == resolve_uri(self.url)
 
 
 class ProgramHtmlEval(BaseModel):
@@ -258,7 +262,8 @@ class WebArenaChallengeSpec(BaseModel):
             f"{' and '.join(s.base_url for s in sites)}.\n\n"
             + "\n".join(
                 s.additional_info.format(url=s.base_url)
-                for s in sites if s.additional_info
+                for s in sites
+                if s.additional_info
             )
         ).strip()
 
@@ -391,7 +396,9 @@ class WebArenaChallenge(BaseChallenge):
         if request.config.getoption("--nc"):
             timeout = 100000
         elif cutoff := request.config.getoption("--cutoff"):
-            timeout = int(cutoff)
+            timeout = int(cutoff)  # type: ignore
+
+        assert isinstance(request.node, pytest.Item)
 
         n_steps = 0
         timed_out = None
@@ -400,7 +407,7 @@ class WebArenaChallenge(BaseChallenge):
         eval_results_per_step: list[list[tuple[_Eval, EvalResult]]] = []
         try:
             async for step in self.run_challenge(
-                config, timeout, mock=request.config.getoption("--mock")
+                config, timeout, mock=bool(request.config.getoption("--mock"))
             ):
                 if not step.output:
                     logger.warn(f"Step has no output: {step}")
@@ -415,7 +422,7 @@ class WebArenaChallenge(BaseChallenge):
                 )
 
                 step_eval_results = self.evaluate_step_result(
-                    step, mock=request.config.getoption("--mock")
+                    step, mock=bool(request.config.getoption("--mock"))
                 )
                 logger.debug(f"Intermediary results: {step_eval_results}")
                 eval_results_per_step.append(step_eval_results)
@@ -462,7 +469,7 @@ class WebArenaChallenge(BaseChallenge):
 
 
 def load_webarena_challenges(
-    skip_unavailable: bool = True
+    skip_unavailable: bool = True,
 ) -> Iterator[type[WebArenaChallenge]]:
     logger.info("Loading WebArena challenges...")
 
@@ -123,8 +123,10 @@ def check_regression(request: pytest.FixtureRequest) -> None:
     with contextlib.suppress(FileNotFoundError):
         rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
 
+        assert isinstance(request.node, pytest.Function)
+        assert isinstance(request.node.parent, pytest.Class)
         test_name = request.node.parent.name
-        challenge_location = getattr(request.node.parent.cls, "CHALLENGE_LOCATION", "")
+        challenge_location = getattr(request.node.cls, "CHALLENGE_LOCATION", "")
         skip_string = f"Skipping {test_name} at {challenge_location}"
 
         # Check if the test name exists in the regression tests
@@ -148,7 +150,9 @@ def mock(request: pytest.FixtureRequest) -> bool:
     Returns:
         bool: Whether `--mock` is set for this session.
     """
-    return request.config.getoption("--mock")
+    mock = request.config.getoption("--mock")
+    assert isinstance(mock, bool)
+    return mock
 
 
 test_reports: dict[str, Test] = {}
@@ -221,7 +225,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
 
 
 def pytest_collection_modifyitems(
-    items: list[pytest.Item], config: pytest.Config
+    items: list[pytest.Function], config: pytest.Config
 ) -> None:
     """
     Pytest hook that is called after initial test collection has been performed.
@@ -248,8 +252,9 @@ def pytest_collection_modifyitems(
     i = 0
     while i < len(items):
         item = items[i]
+        assert item.cls and issubclass(item.cls, BaseChallenge)
         challenge = item.cls
-        challenge_name = item.cls.__name__
+        challenge_name = challenge.info.name
 
         if not issubclass(challenge, BaseChallenge):
             item.warn(
@@ -18,9 +18,9 @@ def run_benchmark(
     maintain: bool = False,
     improve: bool = False,
     explore: bool = False,
-    tests: tuple[str] = tuple(),
-    categories: tuple[str] = tuple(),
-    skip_categories: tuple[str] = tuple(),
+    tests: tuple[str, ...] = tuple(),
+    categories: tuple[str, ...] = tuple(),
+    skip_categories: tuple[str, ...] = tuple(),
     attempts_per_challenge: int = 1,
     mock: bool = False,
     no_dep: bool = False,
@@ -53,9 +53,9 @@ class SingletonReportManager:
     @classmethod
     def clear_instance(cls):
         cls.instance = None
-        cls.INFO_MANAGER = None
-        cls.REGRESSION_MANAGER = None
-        cls.SUCCESS_RATE_TRACKER = None
+        del cls.INFO_MANAGER
+        del cls.REGRESSION_MANAGER
+        del cls.SUCCESS_RATE_TRACKER
 
 
 class BaseReportManager:
@@ -99,7 +99,8 @@ class BaseReportManager:
 class SessionReportManager(BaseReportManager):
     """Abstracts interaction with the regression tests file"""
 
-    tests: dict[str, Test] | Report
+    tests: dict[str, Test]
+    report: Report | None = None
 
     def __init__(self, report_file: Path, benchmark_start_time: datetime):
         super().__init__(report_file)
@@ -109,20 +110,21 @@ class SessionReportManager(BaseReportManager):
 
     def save(self) -> None:
         with self.report_file.open("w") as f:
-            if isinstance(self.tests, Report):
-                f.write(self.tests.json(indent=4))
+            if self.report:
+                f.write(self.report.json(indent=4))
             else:
                 json.dump({k: v.dict() for k, v in self.tests.items()}, f, indent=4)
 
     def load(self) -> None:
         super().load()
-        if "tests" in self.tests:  # type: ignore
-            self.tests = Report.parse_obj(self.tests)
+        if "tests" in self.tests:
+            self.report = Report.parse_obj(self.tests)
         else:
             self.tests = {n: Test.parse_obj(d) for n, d in self.tests.items()}
 
     def add_test_report(self, test_name: str, test_report: Test) -> None:
-        if isinstance(self.tests, Report):
+        if self.report:
             raise RuntimeError("Session report already finalized")
 
         if test_name.startswith("Test"):
@@ -134,10 +136,10 @@ class SessionReportManager(BaseReportManager):
     def finalize_session_report(self, config: AgentBenchmarkConfig) -> None:
         command = " ".join(sys.argv)
 
-        if isinstance(self.tests, Report):
+        if self.report:
             raise RuntimeError("Session report already finalized")
 
-        self.tests = Report(
+        self.report = Report(
             command=command.split(os.sep)[-1],
             benchmark_git_commit_sha="---",
             agent_git_commit_sha="---",
@@ -156,7 +158,7 @@ class SessionReportManager(BaseReportManager):
             config=config.dict(exclude={"reports_folder"}, exclude_none=True),
         )
 
-        agent_categories = get_highest_achieved_difficulty_per_category(self.tests)
+        agent_categories = get_highest_achieved_difficulty_per_category(self.report)
         if len(agent_categories) > 1:
             save_single_radar_chart(
                 agent_categories,
@@ -166,8 +168,8 @@ class SessionReportManager(BaseReportManager):
         self.save()
 
     def get_total_costs(self):
-        if isinstance(self.tests, Report):
-            tests = self.tests.tests
+        if self.report:
+            tests = self.report.tests
         else:
             tests = self.tests
 
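The theme of this file's changes: `tests` previously held either a `dict[str, Test]` or a finished `Report`, so every access needed `isinstance` narrowing. Splitting the union into two attributes gives each a single static type, and `self.report` doubles as the "finalized" flag. A minimal sketch of the pattern (illustrative class, not the actual benchmark code):

```python
from typing import Optional


class SessionState:
    """Illustrative only: one narrow type per attribute instead of a union."""

    def __init__(self) -> None:
        self.tests: dict[str, dict] = {}
        self.report: Optional[dict] = None  # set once, on finalization

    def finalize(self) -> None:
        if self.report:  # replaces isinstance(self.tests, Report) checks
            raise RuntimeError("Session report already finalized")
        self.report = {"tests": self.tests}
```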
@@ -3,7 +3,7 @@ Model definitions used internally and for reports generated during command-line
 """
 
 import logging
-from typing import Any, Dict, List
+from typing import Annotated, Any, Dict, List
 
 from agent_protocol_client import Step
 from pydantic import BaseModel, Field, constr, validator
@@ -88,7 +88,7 @@ class Test(BaseModel):
 class ReportBase(BaseModel):
     command: str
     completion_time: str | None = None
-    benchmark_start_time: constr(regex=datetime_format)
+    benchmark_start_time: Annotated[str, constr(regex=datetime_format)]
     metrics: MetricsOverall
     config: Dict[str, str | dict[str, str]]
     agent_git_commit_sha: str | None = None
@@ -1,4 +1,5 @@
 """Model definitions for use in the API"""
+from typing import Annotated
 
 from pydantic import BaseModel, constr
 
@@ -36,7 +37,7 @@ class RunDetails(BaseModel):
     run_id: str | None = None
     command: str
     completion_time: str | None = None
-    benchmark_start_time: constr(regex=datetime_format)
+    benchmark_start_time: Annotated[str, constr(regex=datetime_format)]
 
 
 class BenchmarkRun(BaseModel):
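Both model files above get the same treatment: pydantic v1's `constr(...)` builds a type at runtime, which type checkers reject when it is used directly as an annotation. Wrapping it as `Annotated[str, constr(...)]` keeps the regex validation while presenting a plain `str` to the checker. A minimal sketch of the pattern (the model and regex are hypothetical, mirroring the hunks above):

```python
from typing import Annotated

from pydantic import BaseModel, constr


class Event(BaseModel):
    # The type checker sees `str`; pydantic still enforces the regex.
    start_time: Annotated[str, constr(regex=r"\d{4}-\d{2}-\d{2}")]


Event(start_time="2024-05-28")  # ok
Event(start_time="yesterday")   # raises pydantic.ValidationError
```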
@@ -1,14 +1,10 @@
 from __future__ import annotations
 
-from typing import Optional
+from typing import Any, Optional
 
 from pydantic import BaseModel, Field
 
 
-class TaskInput(BaseModel):
-    pass
-
-
 class TaskRequestBody(BaseModel):
     input: str = Field(
         ...,
@@ -16,7 +12,7 @@ class TaskRequestBody(BaseModel):
         description="Input prompt for the task.",
         example="Write the words you receive to the file 'output.txt'.",
     )
-    additional_input: Optional[TaskInput] = {}
+    additional_input: Optional[dict[str, Any]] = Field(default_factory=dict)
 
 
 class TaskEvalRequestBody(TaskRequestBody):
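The old default `= {}` never matched the declared type `Optional[TaskInput]` in the first place, which is exactly the kind of mismatch the new type checking surfaces; `Field(default_factory=dict)` pairs naturally with the corrected `dict[str, Any]` annotation and gives each instance its own dict. A minimal sketch (hypothetical model, pydantic v1 style):

```python
from typing import Any, Optional

from pydantic import BaseModel, Field


class Payload(BaseModel):
    # Annotation and default now agree; the factory runs per instance.
    extra: Optional[dict[str, Any]] = Field(default_factory=dict)


a, b = Payload(), Payload()
assert a.extra == {} and a.extra is not b.extra
```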
@@ -32,7 +32,10 @@ def _add_ini_and_option(
     default: str | bool | int,
     **kwargs: Any,
 ) -> None:
-    """Add an option to both the ini file as well as the command line flags, with the latter overriding the former."""
+    """
+    Add an option to both the ini file and the command line flags.
+    Command line flags/options takes precedence over the ini config.
+    """
     parser.addini(
         name,
         help + " This overrides the similarly named option from the config.",
@@ -44,7 +47,10 @@ def _add_ini_and_option(
 def _get_ini_or_option(
     config: Any, name: str, choices: Optional[list[str]]
 ) -> str | None:
-    """Get an option from either the ini file or the command line flags, the latter taking precedence."""
+    """
+    Get an option from either the ini file or the command line flags,
+    with the latter taking precedence.
+    """
     value = config.getini(name)
     if value is not None and choices is not None and value not in choices:
         raise ValueError(
@@ -73,7 +79,7 @@ def pytest_addoption(parser: Parser) -> None:
         default=False,
         help=(
             "List all non-nodeid dependency names + the tests they resolve to. "
-            "Will also list all nodeid dependency names when verbosity is high enough."
+            "Will also list all nodeid dependency names in verbose mode."
         ),
     )
 
@@ -83,7 +89,10 @@ def pytest_addoption(parser: Parser) -> None:
         "--list-processed-dependencies",
         action="store_true",
         default=False,
-        help="List all dependencies of all tests as a list of nodeids + the names that could not be resolved.",
+        help=(
+            "List all dependencies of all tests as a list of nodeids "
+            "+ the names that could not be resolved."
+        ),
     )
 
     # Add an ini option + flag to choose the action to take for failed dependencies
@@ -94,7 +103,8 @@ def pytest_addoption(parser: Parser) -> None:
         name="failed_dependency_action",
         help=(
             "The action to take when a test has dependencies that failed. "
-            'Use "run" to run the test anyway, "skip" to skip the test, and "fail" to fail the test.'
+            'Use "run" to run the test anyway, "skip" to skip the test, '
+            'and "fail" to fail the test.'
         ),
         default="skip",
         choices=DEPENDENCY_PROBLEM_ACTIONS.keys(),
@@ -107,8 +117,10 @@ def pytest_addoption(parser: Parser) -> None:
         group,
         name="missing_dependency_action",
         help=(
-            "The action to take when a test has dependencies that cannot be found within the current scope. "
-            'Use "run" to run the test anyway, "skip" to skip the test, and "fail" to fail the test.'
+            "The action to take when a test has dependencies that cannot be found "
+            "within the current scope. "
+            'Use "run" to run the test anyway, "skip" to skip the test, '
+            'and "fail" to fail the test.'
         ),
         default="warning",
         choices=DEPENDENCY_PROBLEM_ACTIONS.keys(),
@@ -139,7 +151,7 @@ def pytest_configure(config: Any) -> None:
 
 
 @pytest.hookimpl(trylast=True)
-def pytest_collection_modifyitems(config: Any, items: list[Item]) -> None:
+def pytest_collection_modifyitems(config: Any, items: list[pytest.Function]) -> None:
     manager = managers[-1]
 
     # Register the founds tests on the manager
@@ -3,7 +3,7 @@
 # The name of the marker used
 MARKER_NAME = "depends"
 
-# The name of the keyword argument for the marker that contains custom name(s) for the tests
+# The name of the kwarg for 'depends' markers that contains custom name(s) for the tests
 MARKER_KWARG_ID = "name"
 
 # The name of the keyword argument for the marker that specifies the tests to depend on
@@ -57,8 +57,10 @@ def curved_edges(
     """
     ax = plt.gca()
     for u, v, data in G.edges(data=True):
-        src = np.array(pos[u])
-        dst = np.array(pos[v])
+        _src = pos[u]
+        _dst = pos[v]
+        src = np.array(_src)
+        dst = np.array(_dst)
 
         same_level = abs(src[1] - dst[1]) < 0.01
 
@@ -68,7 +70,7 @@ def curved_edges(
             arrow = patches.FancyArrowPatch(
                 posA=curve[0],  # type: ignore
                 posB=curve[-1],  # type: ignore
-                connectionstyle=f"arc3,rad=0.2",
+                connectionstyle="arc3,rad=0.2",
                 color="gray",
                 arrowstyle="-|>",
                 mutation_scale=15.0,
@@ -80,8 +82,8 @@ def curved_edges(
         else:
             ax.annotate(
                 "",
-                xy=dst,
-                xytext=src,
+                xy=_dst,
+                xytext=_src,
                 arrowprops=dict(
                     arrowstyle="-|>", color="gray", lw=1, shrinkA=10, shrinkB=10
                 ),
@@ -89,7 +91,8 @@ def curved_edges(
 
 
 def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, float]]:
-    """Compute positions as a tree layout centered on the root with alternating vertical shifts."""
+    """Compute positions as a tree layout centered on the root
+    with alternating vertical shifts."""
     bfs_tree = nx.bfs_tree(graph, source=root_node)
     levels = {
         node: depth
@@ -137,7 +140,7 @@ def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, flo
 def graph_spring_layout(
     dag: nx.DiGraph, labels: Dict[Any, str], tree: bool = True
 ) -> None:
-    num_nodes = len(dag.nodes())
+    num_nodes = len(list(dag.nodes()))
     # Setting up the figure and axis
     fig, ax = plt.subplots()
     ax.axis("off")  # Turn off the axis
@@ -288,7 +291,8 @@ def graph_interactive_network(
 
         # Optionally, save to a file
         # Sync with the flutter UI
-        # this literally only works in the AutoGPT repo, but this part of the code is not reached if BUILD_SKILL_TREE is false
+        # this literally only works in the AutoGPT repo, but this part of the code
+        # is not reached if BUILD_SKILL_TREE is false
         write_pretty_json(graph_data, flutter_app_path / "tree_structure.json")
         validate_skill_tree(graph_data, "")
 
@@ -332,11 +336,13 @@ def graph_interactive_network(
 
 def extract_subgraph_based_on_category(graph, category):
     """
-    Extracts a subgraph that includes all nodes and edges required to reach all nodes with a specified category.
+    Extracts a subgraph that includes all nodes and edges required to reach all nodes
+    with a specified category.
 
     :param graph: The original graph.
     :param category: The target category.
-    :return: Subgraph with nodes and edges required to reach the nodes with the given category.
+    :return: Subgraph with nodes and edges required to reach the nodes
+        with the given category.
     """
 
     subgraph = {"nodes": [], "edges": []}
@@ -424,7 +430,8 @@ def get_roots(graph):
 
 def validate_skill_tree(graph, skill_tree_name):
     """
-    Validate if a given graph represents a valid skill tree and raise appropriate exceptions if not.
+    Validate if a given graph represents a valid skill tree
+    and raise appropriate exceptions if not.
 
     :param graph: A dictionary representing the graph with 'nodes' and 'edges'.
     :raises: ValueError with a description of the invalidity.
@@ -434,7 +441,8 @@ def validate_skill_tree(graph, skill_tree_name):
     if cycle_path:
         cycle_str = " -> ".join(cycle_path)
         raise ValueError(
-            f"{skill_tree_name} skill tree is circular! Circular path detected: {cycle_str}."
+            f"{skill_tree_name} skill tree is circular! "
+            f"Detected circular path: {cycle_str}."
         )
 
     # Check for multiple roots
@@ -1,18 +1,19 @@
 """
 A module to manage dependencies between pytest tests.
 
-This module provides the methods implementing the main logic. These are used in the pytest hooks that are in
-__init__.py.
+This module provides the methods implementing the main logic.
+These are used in the pytest hooks that are in __init__.py.
 """
 
 import collections
-import json
 import os
 from typing import Any, Generator
 
 import colorama
 import networkx
-from _pytest.nodes import Item
+from pytest import Function, Item
 
+from agbenchmark.challenges.base import BaseChallenge
+
 from .constants import MARKER_KWARG_DEPENDENCIES, MARKER_NAME
 from .graphs import graph_interactive_network
@@ -38,7 +39,8 @@ class TestResult(object):
         )
         if result.when in self.results:
             raise AttributeError(
-                f"Received multiple results for step {result.when} of test {self.nodeid}"
+                f"Received multiple results for step {result.when} "
+                f"of test {self.nodeid}"
             )
         self.results[result.when] = result.outcome
 
@@ -66,7 +68,7 @@ class TestDependencies(object):
             for dep in marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])
         ]
         for dependency in dependencies:
-            # If the name is not known, try to make it absolute (ie file::[class::]method)
+            # If the name is not known, try to make it absolute (file::[class::]method)
             if dependency not in manager.name_to_nodeids:
                 absolute_dependency = get_absolute_nodeid(dependency, self.nodeid)
                 if absolute_dependency in manager.name_to_nodeids:
@@ -86,20 +88,20 @@ class DependencyManager(object):
     def __init__(self) -> None:
         """Create a new DependencyManager."""
         self.options: dict[str, Any] = {}
-        self._items: list[Item] | None = None
+        self._items: list[Function] | None = None
         self._name_to_nodeids: Any = None
         self._nodeid_to_item: Any = None
        self._results: Any = None
 
     @property
-    def items(self) -> list[Item]:
+    def items(self) -> list[Function]:
         """The collected tests that are managed by this instance."""
         if self._items is None:
             raise AttributeError("The items attribute has not been set yet")
         return self._items
 
     @items.setter
-    def items(self, items: list[Item]) -> None:
+    def items(self, items: list[Function]) -> None:
         if self._items is not None:
             raise AttributeError("The items attribute has already been set")
         self._items = items
@@ -125,7 +127,8 @@ class DependencyManager(object):
         for item in items:
             nodeid = clean_nodeid(item.nodeid)
             # Process the dependencies of this test
-            # This uses the mappings created in the previous loop, and can thus not be merged into that loop
+            # This uses the mappings created in the previous loop,
+            # and can thus not be merged into that loop
             self._dependencies[nodeid] = TestDependencies(item, self)
 
     @property
@@ -135,7 +138,7 @@ class DependencyManager(object):
         return self._name_to_nodeids
 
     @property
-    def nodeid_to_item(self) -> dict[str, Item]:
+    def nodeid_to_item(self) -> dict[str, Function]:
         """A mapping from node ids to test items."""
         assert self.items is not None
         return self._nodeid_to_item
@@ -194,7 +197,9 @@ class DependencyManager(object):
 
     @property
     def sorted_items(self) -> Generator:
-        """Get a sorted list of tests where all tests are sorted after their dependencies."""
+        """
+        Get a sorted list of tests where all tests are sorted after their dependencies.
+        """
         # Build a directed graph for sorting
         build_skill_tree = os.getenv("BUILD_SKILL_TREE")
         BUILD_SKILL_TREE = (
@@ -202,8 +207,8 @@ class DependencyManager(object):
         )
         dag = networkx.DiGraph()
 
-        # Insert all items as nodes, to prevent items that have no dependencies and are not dependencies themselves from
-        # being lost
+        # Insert all items as nodes, to prevent items that have no dependencies
+        # and are not dependencies themselves from being lost
         dag.add_nodes_from(self.items)
 
         # Insert edges for all the dependencies
@@ -214,11 +219,8 @@ class DependencyManager(object):
 
         labels = {}
         for item in self.items:
-            try:
-                with open(item.cls.CHALLENGE_LOCATION) as f:
-                    data = json.load(f)
-            except:
-                data = {}
+            assert item.cls and issubclass(item.cls, BaseChallenge)
+            data = item.cls.info.dict()
 
             node_name = get_name(item)
             data["name"] = node_name
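Narrowing `Item` to `pytest.Function` is what lets the new `assert item.cls and issubclass(item.cls, BaseChallenge)` line type-check: `Function` is the `Item` subclass pytest uses for collected Python test functions, and it exposes the `.cls` attribute that a plain `Item` does not declare. A minimal sketch of the narrowing pattern (the hook body is hypothetical, not the plugin's actual code):

```python
import pytest


def process(items: list[pytest.Function]) -> None:
    for item in items:
        # `item.cls` is the class the test method was collected from (or None);
        # the assert narrows it for the type checker before attribute access.
        assert item.cls is not None
        print(item.cls.__name__, item.nodeid)
```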
@@ -38,7 +38,8 @@ def strip_nodeid_parameters(nodeid: str) -> str:
 
 def get_absolute_nodeid(nodeid: str, scope: str) -> str:
     """
-    Transform a possibly relative node id to an absolute one using the scope in which it is used.
+    Transform a possibly relative node id to an absolute one
+    using the scope in which it is used.
 
     >>> scope = 'test_file.py::TestClass::test'
     >>> get_absolute_nodeid('test2', scope)
@@ -49,7 +50,7 @@ def get_absolute_nodeid(nodeid: str, scope: str) -> str:
     'test_file2.py::TestClass2::test2'
     """
     parts = nodeid.split("::")
-    # Completely relative (test_name), so add the full current scope (either file::class or file)
+    # Completely relative (test_name): add the full current scope (file::class or file)
     if len(parts) == 1:
         base_nodeid = scope.rsplit("::", 1)[0]
         nodeid = f"{base_nodeid}::{nodeid}"
@@ -15,7 +15,8 @@ def get_data_from_helicone(challenge: str) -> Optional[float]:
     # Define the endpoint of your GraphQL server
     url = "https://www.helicone.ai/api/graphql"
 
-    # Set the headers, usually you'd need to set the content type and possibly an authorization token
+    # Set the headers, usually you'd need to set the content type
+    # and possibly an authorization token
     headers = {"authorization": f"Bearer {os.environ.get('HELICONE_API_KEY')}"}
 
     # Define the query, variables, and operation name
@@ -1,7 +1,18 @@
 SCORING_MAP = {
-    "percentage": "assign a float score that will represent a percentage out of 100. Use decimal points to be even more accurate. 0 represents the worst possible generation, while 100 represents the ideal generation",
-    "scale": "assign an integer score from a scale of 1-10. 1 represents a really bad generation, while 10 represents an ideal generation",
-    "binary": "assign a binary score of either 0 or 1. 0 represents a failure, while 1 represents a success",
+    "percentage": (
+        "assign a float score that will represent a percentage out of 100. "
+        "Use decimal points to be even more accurate. "
+        "0 represents the worst possible generation, "
+        "while 100 represents the ideal generation"
+    ),
+    "scale": (
+        "assign an integer score from a scale of 1-10. "
+        "1 represents a really bad generation, while 10 represents an ideal generation"
+    ),
+    "binary": (
+        "assign a binary score of either 0 or 1. "
+        "0 represents a failure, while 1 represents a success"
+    ),
 }
 
 
@@ -17,7 +28,7 @@ Here is the ideal response you're comparing to based on the task:
 Here is the current machine generated response to the task that you need to evaluate:
 {response}
 
-"""
+"""  # noqa: E501
 
 RUBRIC_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks.
 In order to score the generated texts you will {scoring}. Make sure to factor in rubric into your thinking, deliberation, and final result regarding scoring. Return nothing but a float score.
@@ -31,7 +42,7 @@ Use the below rubric to guide your thinking about scoring:
 Here is the current machine generated response to the task that you need to evaluate:
 {response}
 
-"""
+"""  # noqa: E501
 
 QUESTION_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks.
 In order to score the generated texts you will {scoring}. Make sure to think about whether the generated response answers the question well in order to score accurately. Return nothing but a float score.
@@ -45,12 +56,12 @@ Here is a question that checks if the task was completed correctly:
 Here is the current machine generated response to the task that you need to evaluate:
 {response}
 
-"""
+"""  # noqa: E501
 
 FEW_SHOT_EXAMPLES = """Here are some examples of how to score a machine generated response based on the above:
 {examples}
 
-"""
+"""  # noqa: E501
 
 CUSTOM_PROMPT = """{custom}
 {scoring}
@@ -202,11 +202,15 @@ def sorted_by_enum_index(
     sortable: Iterable[T],
     enum: type[Enum],
     *,
-    key: Callable[[T], Enum | None] = lambda x: x,  # type: ignore
+    key: Optional[Callable[[T], Enum | None]] = None,
     reverse: bool = False,
 ) -> list[T]:
     return sorted(
         sortable,
-        key=lambda x: enum._member_names_.index(e.name) if (e := key(x)) else 420e3,
+        key=lambda x: (
+            enum._member_names_.index(e.name)  # type: ignore
+            if (e := key(x) if key else x)
+            else 420e3
+        ),
         reverse=reverse,
     )
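Dropping the `lambda x: x` default in favor of `key=None` removes an unsound cast: the old default only passed type checking via `# type: ignore`, since `T` is not necessarily an `Enum`. A minimal usage sketch of the helper as redefined above (assumes `sorted_by_enum_index` is in scope; the `Priority` enum is hypothetical):

```python
from enum import Enum


class Priority(Enum):
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"


# With key=None, each element is treated as the enum member itself,
# and items come back in enum declaration order:
tasks = [Priority.LOW, Priority.HIGH, Priority.MEDIUM]
assert sorted_by_enum_index(tasks, Priority) == [
    Priority.HIGH,
    Priority.MEDIUM,
    Priority.LOW,
]
```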
@@ -1,13 +0,0 @@
-[mypy]
-namespace_packages = True
-follow_imports = skip
-check_untyped_defs = True
-disallow_untyped_defs = True
-exclude = ^(agbenchmark/challenges/|agent/|venv|venv-dev)
-ignore_missing_imports = True
-
-[mypy-agbenchmark.utils.data_types.*]
-ignore_errors = True
-
-[mypy-numpy.*]
-ignore_errors = True
213
benchmark/poetry.lock
generated
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "agent-protocol-client"
@@ -197,63 +197,49 @@ tests = ["attrs[tests-no-zope]", "zope-interface"]
 tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
 tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
 
-[[package]]
-name = "autoflake"
-version = "1.7.8"
-description = "Removes unused imports and unused variables"
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "autoflake-1.7.8-py3-none-any.whl", hash = "sha256:46373ef69b6714f5064c923bb28bd797c4f8a9497f557d87fc36665c6d956b39"},
-    {file = "autoflake-1.7.8.tar.gz", hash = "sha256:e7e46372dee46fa1c97acf310d99d922b63d369718a270809d7c278d34a194cf"},
-]
-
-[package.dependencies]
-pyflakes = ">=1.1.0,<3"
-tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""}
-
 [[package]]
 name = "black"
-version = "22.3.0"
+version = "23.12.1"
 description = "The uncompromising code formatter."
 optional = false
-python-versions = ">=3.6.2"
+python-versions = ">=3.8"
 files = [
-    {file = "black-22.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2497f9c2386572e28921fa8bec7be3e51de6801f7459dffd6e62492531c47e09"},
-    {file = "black-22.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5795a0375eb87bfe902e80e0c8cfaedf8af4d49694d69161e5bd3206c18618bb"},
-    {file = "black-22.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3556168e2e5c49629f7b0f377070240bd5511e45e25a4497bb0073d9dda776a"},
-    {file = "black-22.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67c8301ec94e3bcc8906740fe071391bce40a862b7be0b86fb5382beefecd968"},
-    {file = "black-22.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:fd57160949179ec517d32ac2ac898b5f20d68ed1a9c977346efbac9c2f1e779d"},
-    {file = "black-22.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cc1e1de68c8e5444e8f94c3670bb48a2beef0e91dddfd4fcc29595ebd90bb9ce"},
-    {file = "black-22.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2fc92002d44746d3e7db7cf9313cf4452f43e9ea77a2c939defce3b10b5c82"},
-    {file = "black-22.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:a6342964b43a99dbc72f72812bf88cad8f0217ae9acb47c0d4f141a6416d2d7b"},
-    {file = "black-22.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:328efc0cc70ccb23429d6be184a15ce613f676bdfc85e5fe8ea2a9354b4e9015"},
-    {file = "black-22.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06f9d8846f2340dfac80ceb20200ea5d1b3f181dd0556b47af4e8e0b24fa0a6b"},
-    {file = "black-22.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4efa5fad66b903b4a5f96d91461d90b9507a812b3c5de657d544215bb7877a"},
-    {file = "black-22.3.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8477ec6bbfe0312c128e74644ac8a02ca06bcdb8982d4ee06f209be28cdf163"},
-    {file = "black-22.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:637a4014c63fbf42a692d22b55d8ad6968a946b4a6ebc385c5505d9625b6a464"},
-    {file = "black-22.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:863714200ada56cbc366dc9ae5291ceb936573155f8bf8e9de92aef51f3ad0f0"},
-    {file = "black-22.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10dbe6e6d2988049b4655b2b739f98785a884d4d6b85bc35133a8fb9a2233176"},
-    {file = "black-22.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:cee3e11161dde1b2a33a904b850b0899e0424cc331b7295f2a9698e79f9a69a0"},
-    {file = "black-22.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5891ef8abc06576985de8fa88e95ab70641de6c1fca97e2a15820a9b69e51b20"},
-    {file = "black-22.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:30d78ba6bf080eeaf0b7b875d924b15cd46fec5fd044ddfbad38c8ea9171043a"},
-    {file = "black-22.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ee8f1f7228cce7dffc2b464f07ce769f478968bfb3dd1254a4c2eeed84928aad"},
-    {file = "black-22.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ee227b696ca60dd1c507be80a6bc849a5a6ab57ac7352aad1ffec9e8b805f21"},
-    {file = "black-22.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:9b542ced1ec0ceeff5b37d69838106a6348e60db7b8fdd245294dc1d26136265"},
-    {file = "black-22.3.0-py3-none-any.whl", hash = "sha256:bc58025940a896d7e5356952228b68f793cf5fcb342be703c3a2669a1488cb72"},
-    {file = "black-22.3.0.tar.gz", hash = "sha256:35020b8886c022ced9282b51b5a875b6d1ab0c387b31a065b84db7c33085ca79"},
+    {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"},
+    {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"},
+    {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"},
+    {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"},
+    {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"},
+    {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"},
+    {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"},
+    {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"},
+    {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"},
+    {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"},
+    {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"},
+    {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"},
+    {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"},
+    {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"},
+    {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"},
+    {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"},
+    {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"},
+    {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"},
+    {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"},
+    {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"},
+    {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"},
+    {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"},
 ]
 
 [package.dependencies]
 click = ">=8.0.0"
 mypy-extensions = ">=0.4.3"
+packaging = ">=22.0"
 pathspec = ">=0.9.0"
 platformdirs = ">=2"
 tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
 
 [package.extras]
 colorama = ["colorama (>=0.4.3)"]
-d = ["aiohttp (>=3.7.4)"]
+d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
 jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
 uvloop = ["uvloop (>=0.15.2)"]
 
@@ -558,6 +544,73 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.6.1)", "types-Pill
 test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
 test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
 
+[[package]]
+name = "coverage"
+version = "7.5.1"
+description = "Code coverage measurement for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
+    {file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
+    {file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
+    {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
+    {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
+    {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
+    {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
+    {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
+    {file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
+    {file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
+    {file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
+    {file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
+    {file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
+    {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
+    {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
+    {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
+    {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
+    {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
+    {file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
+    {file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
+    {file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
+    {file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
+    {file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
+    {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
+    {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
+    {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
+    {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
+    {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
+    {file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
+    {file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
+    {file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
+    {file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
+    {file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
+    {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
+    {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
+    {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
+    {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
+    {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
+    {file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
+    {file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
+    {file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
+    {file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
+    {file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
+    {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
+    {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
+    {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
+    {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
+    {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
+    {file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
+    {file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
+    {file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
+    {file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
+]
+
+[package.dependencies]
+tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
+
+[package.extras]
+toml = ["tomli"]
+
 [[package]]
 name = "cycler"
 version = "0.12.1"
@@ -671,19 +724,19 @@ typing = ["typing-extensions (>=4.8)"]
 
 [[package]]
 name = "flake8"
-version = "3.9.2"
+version = "7.0.0"
 description = "the modular source code checker: pep8 pyflakes and co"
 optional = false
-python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
+python-versions = ">=3.8.1"
 files = [
-    {file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"},
-    {file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"},
+    {file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"},
+    {file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"},
 ]
 
 [package.dependencies]
-mccabe = ">=0.6.0,<0.7.0"
-pycodestyle = ">=2.7.0,<2.8.0"
-pyflakes = ">=2.3.0,<2.4.0"
+mccabe = ">=0.7.0,<0.8.0"
+pycodestyle = ">=2.11.0,<2.12.0"
+pyflakes = ">=3.2.0,<3.3.0"
 
 [[package]]
 name = "fonttools"
@@ -1376,13 +1429,13 @@ traitlets = "*"
 
 [[package]]
 name = "mccabe"
-version = "0.6.1"
+version = "0.7.0"
 description = "McCabe checker, plugin for flake8"
 optional = false
-python-versions = "*"
+python-versions = ">=3.6"
 files = [
-    {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"},
-    {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
+    {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
+    {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
 ]
 
 [[package]]
@@ -1973,13 +2026,13 @@ pyasn1 = ">=0.4.6,<0.6.0"
 
 [[package]]
 name = "pycodestyle"
-version = "2.7.0"
+version = "2.11.1"
 description = "Python style guide checker"
 optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+python-versions = ">=3.8"
 files = [
-    {file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"},
-    {file = "pycodestyle-2.7.0.tar.gz", hash = "sha256:c389c1d06bf7904078ca03399a4816f974a1d590090fecea0c63ec26ebaf1cef"},
+    {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"},
+    {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"},
 ]
 
 [[package]]
@@ -2047,13 +2100,13 @@ email = ["email-validator (>=1.0.3)"]
 
 [[package]]
 name = "pyflakes"
-version = "2.3.1"
+version = "3.2.0"
 description = "passive checker of Python programs"
 optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+python-versions = ">=3.8"
 files = [
-    {file = "pyflakes-2.3.1-py2.py3-none-any.whl", hash = "sha256:7893783d01b8a89811dd72d7dfd4d84ff098e5eed95cfa8905b22bbffe52efc3"},
-    {file = "pyflakes-2.3.1.tar.gz", hash = "sha256:f5bc8ecabc05bb9d291eb5203d6810b49040f6ff446a756326104746cc00c1db"},
+    {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"},
+    {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"},
 ]
 
 [[package]]
@@ -2085,6 +2138,24 @@ files = [
 [package.extras]
 diagrams = ["jinja2", "railroad-diagrams"]
 
+[[package]]
+name = "pyright"
+version = "1.1.364"
+description = "Command line wrapper for pyright"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pyright-1.1.364-py3-none-any.whl", hash = "sha256:865f1e02873c5dc7427c95acf53659a118574010e6fb364e27e47ec5c46a9f26"},
+    {file = "pyright-1.1.364.tar.gz", hash = "sha256:612a2106a4078ec57efc22b5620729e9bdf4a3c17caba013b534bd33f7d08e5a"},
+]
+
+[package.dependencies]
+nodeenv = ">=1.6.0"
+
+[package.extras]
+all = ["twine (>=3.4.1)"]
+dev = ["twine (>=3.4.1)"]
+
 [[package]]
 name = "pysocks"
 version = "1.7.1"
@@ -2137,6 +2208,24 @@ pytest = ">=7.0.0"
 docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
 testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
 
+[[package]]
+name = "pytest-cov"
+version = "5.0.0"
+description = "Pytest plugin for measuring coverage."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
+    {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
+]
+
+[package.dependencies]
+coverage = {version = ">=5.2.1", extras = ["toml"]}
+pytest = ">=4.6"
+
+[package.extras]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
+
 [[package]]
 name = "python-dateutil"
 version = "2.8.2"
@@ -2774,4 +2863,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9"
+content-hash = "4a980e6d8f54a2f7f6a3c55d4f40ac3a4b27b5ac6573dd2a39e11213a4b126dd"
@ -37,59 +37,49 @@ click-default-group = "^1.2.4"
|
|||||||
tabulate = "^0.9.0"
|
tabulate = "^0.9.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
flake8 = "^3.9.2"
|
black = "^23.12.1"
|
||||||
isort = "^5.9.3"
|
flake8 = "^7.0.0"
|
||||||
black = "22.3"
|
isort = "^5.13.1"
|
||||||
autoflake = "^1.4"
|
pyright = "^1.1.364"
|
||||||
pandas = "^2.0.3"
|
pandas = "^2.0.3"
|
||||||
gspread = "^5.10.0"
|
gspread = "^5.10.0"
|
||||||
oauth2client = "^4.1.3"
|
oauth2client = "^4.1.3"
|
||||||
pre-commit = "^3.3.3"
|
pre-commit = "^3.3.3"
|
||||||
|
pytest-cov = "^5.0.0"
|
||||||
|
|
||||||
|
[tool.poetry.scripts]
|
||||||
|
agbenchmark = "agbenchmark.__main__:cli"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
minversion = "6.0"
|
|
||||||
addopts = "-ra -q"
|
|
||||||
testpaths = [
|
|
||||||
"tests", "agbenchmark",
|
|
||||||
]
|
|
||||||
asyncio_mode = "auto"
|
|
||||||
markers = [
|
|
||||||
"interface",
|
|
||||||
"code",
|
|
||||||
"memory",
|
|
||||||
"iterate",
|
|
||||||
"adaptability",
|
|
||||||
"safety",
|
|
||||||
"content_gen",
|
|
||||||
"product_advisor"
|
|
||||||
]
|
|
||||||
filterwarnings = [
|
|
||||||
"ignore::pytest.PytestAssertRewriteWarning",
|
|
||||||
"ignore::matplotlib.MatplotlibDeprecationWarning"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 88
|
line-length = 88
|
||||||
target-version = ['py310']
|
target-version = ['py310']
|
||||||
include = '\.pyi?$'
|
include = '\.pyi?$'
|
||||||
packages = ["autogpt"]
|
|
||||||
extend-exclude = '(/dist|/.venv|/venv|/build|/agent|agbenchmark/challenges)/'
|
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
multi_line_output = 3
|
skip_glob = ["reports"]
|
||||||
include_trailing_comma = true
|
|
||||||
force_grid_wrap = 0
|
|
||||||
use_parentheses = true
|
|
||||||
ensure_newline_before_comments = true
|
|
||||||
line_length = 88
|
|
||||||
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
|
|
||||||
skip_glob = [".tox", "__pycache__", "*.pyc", "venv*/*", "reports", "venv", "env", "node_modules", ".env", ".venv", "dist", "agent/*", "agbenchmark/challenges/*"]
|
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
|
||||||
agbenchmark = "agbenchmark.__main__:cli"
|
[tool.pyright]
|
||||||
|
pythonVersion = "3.10"
|
||||||
|
exclude = [
|
||||||
|
"notebooks/**",
|
||||||
|
"reports/**",
|
||||||
|
"**/node_modules",
|
||||||
|
"**/__pycache__",
|
||||||
|
"**/.*",
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"../forge/**"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
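The dev-dependency set above maps directly onto the checks the unified CI now runs. As a hedged illustration only, a local runner for the same toolchain might look like the sketch below; the commands, flags, and ordering are assumptions for demonstration, not the actual CI definition:

```python
# run_checks.py -- a minimal sketch; assumes black, isort, flake8, pyright,
# and pytest-cov are installed in the active environment.
import subprocess
import sys

CHECKS: list[list[str]] = [
    ["black", "--check", "."],       # formatting
    ["isort", "--check-only", "."],  # import order
    ["flake8", "."],                 # linting
    ["pyright"],                     # type checking
    ["pytest", "--cov"],             # tests with coverage (pytest-cov)
]

for cmd in CHECKS:
    print(f"$ {' '.join(cmd)}")
    if subprocess.run(cmd).returncode != 0:
        sys.exit(1)  # fail fast, like a CI job
```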
@@ -17,7 +17,7 @@ def print_markdown_report(report_json_file: str):
     report = Report.parse_file(report_json_file)
 
     # Header and metadata
-    click.echo(f"# Benchmark Report")
+    click.echo("# Benchmark Report")
     click.echo(f"- ⌛ **Run time:** `{report.metrics.run_time}`")
     click.echo(
         f" - **Started at:** `{report.benchmark_start_time[:16].replace('T', '` `')}`"
@@ -1,11 +1,16 @@
+import datetime
+import time
+
 import pytest
 import requests
 
 URL_BENCHMARK = "http://localhost:8080/ap/v1"
 URL_AGENT = "http://localhost:8000/ap/v1"
 
-import datetime
-import time
+try:
+    response = requests.get(f"{URL_AGENT}/agent/tasks")
+except requests.exceptions.ConnectionError:
+    pytest.skip("No agent available to test against", allow_module_level=True)
 
 
 @pytest.mark.parametrize(
@@ -20,7 +25,8 @@ import time
         ),
         (
             "f219f3d3-a41b-45a9-a3d0-389832086ee8",
-            "Read the file called file_to_read.txt and write its content to a file called output.txt",
+            "Read the file called file_to_read.txt "
+            "and write its content to a file called output.txt",
             1,
             "ReadFile",
             False,
@@ -28,7 +34,11 @@ import time
     ],
 )
 def test_entire_workflow(
-    eval_id, input_text, expected_artifact_length, test_name, should_be_successful
+    eval_id: str,
+    input_text: str,
+    expected_artifact_length: int,
+    test_name: str,
+    should_be_successful: bool,
 ):
     task_request = {"eval_id": eval_id, "input": input_text}
     response = requests.get(f"{URL_AGENT}/agent/tasks")
@@ -64,7 +74,7 @@ def test_entire_workflow(
     )
     assert step_response.status_code == 200
     step_response = step_response.json()
-    assert step_response["is_last"] == True  # Assuming is_last is always True
+    assert step_response["is_last"] is True  # Assuming is_last is always True
 
     eval_response = requests.post(
         URL_BENCHMARK + "/agent/tasks/" + task_response_benchmark_id + "/evaluations",
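The guard added at the top of this test module is a standard pytest idiom: probe the agent once at collection time and skip the entire module when it is unreachable, instead of failing every test with a connection error. A self-contained sketch of the same pattern (the URL and endpoint here are illustrative):

```python
import pytest
import requests

SERVICE_URL = "http://localhost:8000/ap/v1"  # illustrative base URL

try:
    # One probe at import time; a refused connection means no agent is running.
    requests.get(f"{SERVICE_URL}/agent/tasks", timeout=2)
except requests.exceptions.ConnectionError:
    # allow_module_level=True permits calling skip() outside a test function.
    pytest.skip("No agent available to test against", allow_module_level=True)


def test_agent_lists_tasks():
    response = requests.get(f"{SERVICE_URL}/agent/tasks", timeout=2)
    assert response.status_code == 200
```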
8 cli.py

@@ -131,7 +131,9 @@ def start(agent_name: str, no_setup: bool):
     script_dir = os.path.dirname(os.path.realpath(__file__))
     agent_dir = os.path.join(
         script_dir,
-        f"agents/{agent_name}" if agent_name not in ["autogpt", "forge"] else agent_name,
+        f"agents/{agent_name}"
+        if agent_name not in ["autogpt", "forge"]
+        else agent_name,
     )
     run_command = os.path.join(agent_dir, "run")
     run_bench_command = os.path.join(agent_dir, "run_benchmark")
@@ -247,7 +249,9 @@ def start(agent_name, subprocess_args):
     script_dir = os.path.dirname(os.path.realpath(__file__))
     agent_dir = os.path.join(
         script_dir,
-        f"agents/{agent_name}" if agent_name not in ["autogpt", "forge"] else agent_name,
+        f"agents/{agent_name}"
+        if agent_name not in ["autogpt", "forge"]
+        else agent_name,
     )
     benchmark_script = os.path.join(agent_dir, "run_benchmark")
     if os.path.exists(agent_dir) and os.path.isfile(benchmark_script):
@@ -202,7 +202,7 @@ class MyAgent(Agent):
     def __init__(
         self,
         settings: AgentSettings,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
         file_storage: FileStorage,
         legacy_config: Config,
     ):
@@ -219,7 +219,7 @@ class MyAgent(Agent):
     def __init__(
         self,
         settings: AgentSettings,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
         file_storage: FileStorage,
         legacy_config: Config,
     ):
@@ -1,15 +1,11 @@
 [flake8]
 max-line-length = 88
-select = "E303, W293, W292, E305, E231, E302"
+# Ignore rules that conflict with Black code style
+extend-ignore = E203, W503
 exclude =
-    .tox,
-    __pycache__,
+    .git,
+    __pycache__/,
     *.pyc,
-    .env
-    venv*/*,
-    .venv/*,
-    reports/*,
-    dist/*,
-    agent/*,
-    code,
-    agbenchmark/challenges/*
+    .pytest_cache/,
+    venv*/,
+    .venv/,
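E203 and W503 are the two pycodestyle rules that Black's output violates by design, which is why the hand-picked `select` list is replaced with `extend-ignore`. For illustration, both snippets below are exactly what Black produces and would have tripped the old configuration:

```python
# W503: Black breaks long expressions *before* binary operators.
first_value, second_value, third_value = 1, 2, 3
total = (
    first_value
    + second_value
    + third_value
)

# E203: Black puts whitespace before ":" in slices with complex bounds.
items, start, offset, end = list(range(10)), 1, 2, 8
chunk = items[start + offset : end]
```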
5 forge/.gitignore vendored

@@ -160,7 +160,8 @@ CURRENT_BULLETIN.md
 
 agbenchmark_config/workspace
 agbenchmark_config/reports
-*.sqlite
+*.sqlite*
+*.db
 .agbench
 .agbenchmark
 .benchmarks
@@ -168,7 +169,7 @@ agbenchmark_config/reports
 .pytest_cache
 .vscode
 ig_*
-agent.db
 agbenchmark_config/updates.json
 agbenchmark_config/challenges_already_beaten.json
 agbenchmark_config/temp_folder/*
+test_workspace/
@@ -1,43 +0,0 @@
-repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
-    hooks:
-      - id: check-added-large-files
-        args: ['--maxkb=500']
-      - id: check-byte-order-marker
-      - id: check-case-conflict
-      - id: check-merge-conflict
-      - id: check-symlinks
-      - id: debug-statements
-
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        language_version: python3.11
-
-  - repo: https://github.com/psf/black
-    rev: 23.3.0
-    hooks:
-      - id: black
-        language_version: python3.11
-
-  # - repo: https://github.com/pre-commit/mirrors-mypy
-  #   rev: 'v1.3.0'
-  #   hooks:
-  #     - id: mypy
-
-  - repo: local
-    hooks:
-      - id: autoflake
-        name: autoflake
-        entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
-        language: python
-        types: [ python ]
-      # Mono repo has broken this  TODO: fix
-      # - id: pytest-check
-      #   name: pytest-check
-      #   entry: pytest
-      #   language: system
-      #   pass_filenames: false
-      #   always_run: true
@@ -9,27 +9,24 @@ from forge.logging.config import configure_logging
 logger = logging.getLogger(__name__)
 
 logo = """\n\n
 [ASCII-art banner spelling "AutoGPT" over "forge", ending in `"Y88P"   v0.1.0`;
 this hunk tightens the banner from 27 to 24 lines. The art's column alignment
 was lost in extraction, so it is not reproduced here.]
 \n"""
 
@@ -1,15 +1,7 @@
-from .base import AgentMeta, BaseAgent, BaseAgentConfiguration, BaseAgentSettings
-from .components import (
-    AgentComponent,
-    ComponentEndpointError,
-    ComponentSystemError,
-    EndpointPipelineError,
-)
-from .protocols import (
-    AfterExecute,
-    AfterParse,
-    CommandProvider,
-    DirectiveProvider,
-    ExecutionFailure,
-    MessageProvider,
-)
+from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
+
+__all__ = [
+    "BaseAgent",
+    "BaseAgentConfiguration",
+    "BaseAgentSettings",
+]
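Replacing the blanket re-exports with an explicit `__all__` pins the package's public surface and stops linters from flagging the remaining re-import as unused. A tiny standalone illustration of the idiom (module and names are hypothetical):

```python
# demo_module.py -- __all__ controls what `from demo_module import *` exposes
__all__ = ["public_fn"]


def public_fn() -> str:
    return "part of the public API"


def _helper() -> str:
    # not listed in __all__ (and underscore-prefixed): treated as internal
    return "implementation detail"
```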
@@ -24,7 +24,6 @@ from forge.agent_protocol.models.task import (
     TaskStepsListResponse,
 )
 from forge.file_storage.base import FileStorage
-from forge.utils.exceptions import NotFoundError
 
 logger = logging.getLogger(__name__)
 
@@ -79,7 +78,8 @@ class Agent:
 
         else:
             logger.warning(
-                f"Frontend not found. {frontend_path} does not exist. The frontend will not be served"
+                f"Frontend not found. {frontend_path} does not exist. "
+                "The frontend will not be served."
             )
         app.add_middleware(AgentMiddleware, agent=self)
 
@@ -94,34 +94,25 @@ class Agent:
         """
         Create a task for the agent.
         """
-        try:
-            task = await self.db.create_task(
-                input=task_request.input,
-                additional_input=task_request.additional_input,
-            )
-            return task
-        except Exception as e:
-            raise
+        task = await self.db.create_task(
+            input=task_request.input,
+            additional_input=task_request.additional_input,
+        )
+        return task
 
     async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
         """
         List all tasks that the agent has created.
         """
-        try:
-            tasks, pagination = await self.db.list_tasks(page, pageSize)
-            response = TaskListResponse(tasks=tasks, pagination=pagination)
-            return response
-        except Exception as e:
-            raise
+        tasks, pagination = await self.db.list_tasks(page, pageSize)
+        response = TaskListResponse(tasks=tasks, pagination=pagination)
+        return response
 
     async def get_task(self, task_id: str) -> Task:
         """
         Get a task by ID.
         """
-        try:
-            task = await self.db.get_task(task_id)
-        except Exception as e:
-            raise
+        task = await self.db.get_task(task_id)
         return task
 
     async def list_steps(
@@ -130,12 +121,9 @@ class Agent:
         """
         List the IDs of all steps that the task has created.
         """
-        try:
-            steps, pagination = await self.db.list_steps(task_id, page, pageSize)
-            response = TaskStepsListResponse(steps=steps, pagination=pagination)
-            return response
-        except Exception as e:
-            raise
+        steps, pagination = await self.db.list_steps(task_id, page, pageSize)
+        response = TaskStepsListResponse(steps=steps, pagination=pagination)
+        return response
 
     async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
         """
@@ -147,11 +135,8 @@ class Agent:
         """
         Get a step by ID.
         """
-        try:
-            step = await self.db.get_step(task_id, step_id)
-            return step
-        except Exception as e:
-            raise
+        step = await self.db.get_step(task_id, step_id)
+        return step
 
     async def list_artifacts(
         self, task_id: str, page: int = 1, pageSize: int = 10
@@ -159,62 +144,45 @@ class Agent:
         """
         List the artifacts that the task has created.
         """
-        try:
-            artifacts, pagination = await self.db.list_artifacts(
-                task_id, page, pageSize
-            )
-            return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
-
-        except Exception as e:
-            raise
+        artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
+        return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
 
     async def create_artifact(
-        self, task_id: str, file: UploadFile, relative_path: str
+        self, task_id: str, file: UploadFile, relative_path: str = ""
     ) -> Artifact:
         """
         Create an artifact for the task.
         """
-        data = None
         file_name = file.filename or str(uuid4())
-        try:
-            data = b""
-            while contents := file.file.read(1024 * 1024):
-                data += contents
-            # Check if relative path ends with filename
-            if relative_path.endswith(file_name):
-                file_path = relative_path
-            else:
-                file_path = os.path.join(relative_path, file_name)
-
-            await self.workspace.write_file(file_path, data)
-
-            artifact = await self.db.create_artifact(
-                task_id=task_id,
-                file_name=file_name,
-                relative_path=relative_path,
-                agent_created=False,
-            )
-        except Exception as e:
-            raise
+        data = b""
+        while contents := file.file.read(1024 * 1024):
+            data += contents
+        # Check if relative path ends with filename
+        if relative_path.endswith(file_name):
+            file_path = relative_path
+        else:
+            file_path = os.path.join(relative_path, file_name)
+
+        await self.workspace.write_file(file_path, data)
+
+        artifact = await self.db.create_artifact(
+            task_id=task_id,
+            file_name=file_name,
+            relative_path=relative_path,
+            agent_created=False,
+        )
         return artifact
 
-    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
+    async def get_artifact(self, task_id: str, artifact_id: str) -> StreamingResponse:
         """
         Get an artifact by ID.
         """
-        try:
-            artifact = await self.db.get_artifact(artifact_id)
-            if artifact.file_name not in artifact.relative_path:
-                file_path = os.path.join(artifact.relative_path, artifact.file_name)
-            else:
-                file_path = artifact.relative_path
-            retrieved_artifact = self.workspace.read_file(file_path)
-        except NotFoundError as e:
-            raise
-        except FileNotFoundError as e:
-            raise
-        except Exception as e:
-            raise
+        artifact = await self.db.get_artifact(artifact_id)
+        if artifact.file_name not in artifact.relative_path:
+            file_path = os.path.join(artifact.relative_path, artifact.file_name)
+        else:
+            file_path = artifact.relative_path
+        retrieved_artifact = self.workspace.read_file(file_path, binary=True)
 
         return StreamingResponse(
             BytesIO(retrieved_artifact),
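The pattern removed throughout this file, `try: ... except Exception: raise`, is a no-op wrapper: a bare `raise` re-raises the active exception unchanged, so deleting the wrapper preserves behavior exactly. A minimal standalone illustration:

```python
# Both functions propagate exceptions identically; the second is just shorter.
def fetch_wrapped(lookup: dict[str, str], key: str) -> str:
    try:
        value = lookup[key]
    except Exception:
        raise  # bare re-raise: same exception, same traceback
    return value


def fetch(lookup: dict[str, str], key: str) -> str:
    return lookup[key]
```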
@@ -1,6 +1,7 @@
 from pathlib import Path
 
 import pytest
+from fastapi import UploadFile
 
 from forge.agent_protocol.database.db import AgentDB
 from forge.agent_protocol.models.task import (
@@ -16,16 +17,23 @@ from .agent import Agent
 
 
 @pytest.fixture
-def agent():
+def agent(test_workspace: Path):
     db = AgentDB("sqlite:///test.db")
-    config = FileStorageConfiguration(root=Path("./test_workspace"))
+    config = FileStorageConfiguration(root=test_workspace)
     workspace = LocalFileStorage(config)
     return Agent(db, workspace)
 
 
-@pytest.mark.skip
+@pytest.fixture
+def file_upload():
+    this_file = Path(__file__)
+    file_handle = this_file.open("rb")
+    yield UploadFile(file_handle, filename=this_file.name)
+    file_handle.close()
+
+
 @pytest.mark.asyncio
-async def test_create_task(agent):
+async def test_create_task(agent: Agent):
     task_request = TaskRequestBody(
         input="test_input", additional_input={"input": "additional_test_input"}
     )
@@ -33,20 +41,18 @@ async def test_create_task(agent):
     assert task.input == "test_input"
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
-async def test_list_tasks(agent):
+async def test_list_tasks(agent: Agent):
     task_request = TaskRequestBody(
         input="test_input", additional_input={"input": "additional_test_input"}
     )
-    task = await agent.create_task(task_request)
+    await agent.create_task(task_request)
     tasks = await agent.list_tasks()
     assert isinstance(tasks, TaskListResponse)
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
-async def test_get_task(agent):
+async def test_get_task(agent: Agent):
     task_request = TaskRequestBody(
         input="test_input", additional_input={"input": "additional_test_input"}
     )
@@ -55,9 +61,9 @@ async def test_get_task(agent):
     assert retrieved_task.task_id == task.task_id
 
 
-@pytest.mark.skip
+@pytest.mark.xfail(reason="execute_step is not implemented")
 @pytest.mark.asyncio
-async def test_create_and_execute_step(agent):
+async def test_execute_step(agent: Agent):
     task_request = TaskRequestBody(
         input="test_input", additional_input={"input": "additional_test_input"}
     )
@@ -65,14 +71,14 @@ async def test_create_and_execute_step(agent):
     step_request = StepRequestBody(
         input="step_input", additional_input={"input": "additional_test_input"}
     )
-    step = await agent.create_and_execute_step(task.task_id, step_request)
+    step = await agent.execute_step(task.task_id, step_request)
     assert step.input == "step_input"
     assert step.additional_input == {"input": "additional_test_input"}
 
 
-@pytest.mark.skip
+@pytest.mark.xfail(reason="execute_step is not implemented")
 @pytest.mark.asyncio
-async def test_get_step(agent):
+async def test_get_step(agent: Agent):
     task_request = TaskRequestBody(
         input="test_input", additional_input={"input": "additional_test_input"}
     )
@@ -80,38 +86,52 @@ async def test_get_step(agent):
     step_request = StepRequestBody(
         input="step_input", additional_input={"input": "additional_test_input"}
     )
-    step = await agent.create_and_execute_step(task.task_id, step_request)
+    step = await agent.execute_step(task.task_id, step_request)
     retrieved_step = await agent.get_step(task.task_id, step.step_id)
     assert retrieved_step.step_id == step.step_id
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
-async def test_list_artifacts(agent):
-    artifacts = await agent.list_artifacts()
-    assert isinstance(artifacts, list)
+async def test_list_artifacts(agent: Agent):
+    tasks = await agent.list_tasks()
+    assert tasks.tasks, "No tasks in test.db"
+
+    artifacts = await agent.list_artifacts(tasks.tasks[0].task_id)
+    assert isinstance(artifacts.artifacts, list)
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
-async def test_create_artifact(agent):
+async def test_create_artifact(agent: Agent, file_upload: UploadFile):
     task_request = TaskRequestBody(
         input="test_input", additional_input={"input": "additional_test_input"}
     )
     task = await agent.create_task(task_request)
-    artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
-    artifact = await agent.create_artifact(task.task_id, artifact_request)
-    assert artifact.uri == "test_uri"
+    artifact = await agent.create_artifact(
+        task_id=task.task_id,
+        file=file_upload,
+        relative_path=f"a_dir/{file_upload.filename}",
+    )
+    assert artifact.file_name == file_upload.filename
+    assert artifact.relative_path == f"a_dir/{file_upload.filename}"
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
-async def test_get_artifact(agent):
+async def test_create_and_get_artifact(agent: Agent, file_upload: UploadFile):
     task_request = TaskRequestBody(
         input="test_input", additional_input={"input": "additional_test_input"}
     )
     task = await agent.create_task(task_request)
-    artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
-    artifact = await agent.create_artifact(task.task_id, artifact_request)
+    artifact = await agent.create_artifact(
+        task_id=task.task_id,
+        file=file_upload,
+        relative_path=f"b_dir/{file_upload.filename}",
+    )
+    await file_upload.seek(0)
+    file_upload_content = await file_upload.read()
 
     retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
-    assert retrieved_artifact.artifact_id == artifact.artifact_id
+    retrieved_artifact_content = bytearray()
+    async for b in retrieved_artifact.body_iterator:
+        retrieved_artifact_content.extend(b)  # type: ignore
+    assert retrieved_artifact_content == file_upload_content
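For reference, the read-back in `test_create_and_get_artifact` relies on `StreamingResponse.body_iterator` being an async iterable. A self-contained sketch of the same collection loop (the payload is illustrative):

```python
# Sketch: collecting a StreamingResponse body in an async context, as the
# new artifact test does.
import asyncio
from io import BytesIO

from fastapi.responses import StreamingResponse


async def read_streaming_body(response: StreamingResponse) -> bytes:
    content = bytearray()
    async for chunk in response.body_iterator:
        # Chunks may be str or bytes depending on the source iterator.
        content.extend(chunk if isinstance(chunk, bytes) else chunk.encode())
    return bytes(content)


async def main() -> None:
    response = StreamingResponse(BytesIO(b"hello artifact"))
    assert await read_streaming_body(response) == b"hello artifact"


asyncio.run(main())
```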
@@ -5,22 +5,21 @@ import inspect
 import logging
 from abc import ABCMeta, abstractmethod
 from typing import (
-    TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
+    Generic,
     Iterator,
     Optional,
     ParamSpec,
     TypeVar,
+    cast,
     overload,
 )
 
 from colorama import Fore
 from pydantic import BaseModel, Field, validator
 
-if TYPE_CHECKING:
-    from forge.models.action import ActionProposal, ActionResult
-
 from forge.agent import protocols
 from forge.agent.components import (
     AgentComponent,
@@ -29,15 +28,10 @@ from forge.agent.components import (
 )
 from forge.config.ai_directives import AIDirectives
 from forge.config.ai_profile import AIProfile
-from forge.config.config import ConfigBuilder
 from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName
 from forge.llm.providers.schema import ChatModelInfo
-from forge.models.config import (
-    Configurable,
-    SystemConfiguration,
-    SystemSettings,
-    UserConfigurable,
-)
+from forge.models.action import ActionResult, AnyProposal
+from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable
 
 logger = logging.getLogger(__name__)
 
@@ -133,17 +127,7 @@ class AgentMeta(ABCMeta):
         return instance
 
 
-class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
-    C = TypeVar("C", bound=AgentComponent)
-
-    default_settings = BaseAgentSettings(
-        name="BaseAgent",
-        description=__doc__ if __doc__ else "",
-    )
-
+class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
     def __init__(
         self,
         settings: BaseAgentSettings,
@@ -173,13 +157,13 @@ class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
         return self.config.send_token_limit or self.llm.max_tokens * 3 // 4
 
     @abstractmethod
-    async def propose_action(self) -> ActionProposal:
+    async def propose_action(self) -> AnyProposal:
         ...
 
     @abstractmethod
     async def execute(
         self,
-        proposal: ActionProposal,
+        proposal: AnyProposal,
         user_feedback: str = "",
     ) -> ActionResult:
         ...
@@ -187,7 +171,7 @@ class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
     @abstractmethod
     async def do_not_execute(
         self,
-        denied_proposal: ActionProposal,
+        denied_proposal: AnyProposal,
         user_feedback: str,
     ) -> ActionResult:
         ...
@@ -203,13 +187,16 @@ class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
 
     @overload
     async def run_pipeline(
-        self, protocol_method: Callable[P, None], *args, retry_limit: int = 3
+        self,
+        protocol_method: Callable[P, None | Awaitable[None]],
+        *args,
+        retry_limit: int = 3,
     ) -> list[None]:
         ...
 
     async def run_pipeline(
         self,
-        protocol_method: Callable[P, Iterator[T] | None],
+        protocol_method: Callable[P, Iterator[T] | None | Awaitable[None]],
         *args,
         retry_limit: int = 3,
     ) -> list[T] | list[None]:
@@ -240,7 +227,10 @@ class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
                 )
                 continue
 
-            method = getattr(component, method_name, None)
+            method = cast(
+                Callable[..., Iterator[T] | None | Awaitable[None]] | None,
+                getattr(component, method_name, None),
+            )
             if not callable(method):
                 continue
 
@@ -248,10 +238,9 @@ class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
             while component_attempts < retry_limit:
                 try:
                     component_args = self._selective_copy(args)
-                    if inspect.iscoroutinefunction(method):
-                        result = await method(*component_args)
-                    else:
-                        result = method(*component_args)
+                    result = method(*component_args)
+                    if inspect.isawaitable(result):
+                        result = await result
                     if result is not None:
                         method_result.extend(result)
                     args = component_args
@@ -269,9 +258,9 @@ class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
                     break
                 # Successful pipeline execution
                 break
-            except EndpointPipelineError:
+            except EndpointPipelineError as e:
                 self._trace.append(
-                    f"❌ {Fore.LIGHTRED_EX}{component.__class__.__name__}: "
+                    f"❌ {Fore.LIGHTRED_EX}{e.triggerer.__class__.__name__}: "
                     f"EndpointPipelineError{Fore.RESET}"
                 )
                 # Restart from the beginning on EndpointPipelineError
@@ -36,8 +36,9 @@ class AgentComponent(ABC):
 class ComponentEndpointError(Exception):
     """Error of a single protocol method on a component."""
 
-    def __init__(self, message: str = ""):
+    def __init__(self, message: str, component: AgentComponent):
         self.message = message
+        self.triggerer = component
         super().__init__(message)
 
 
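The new `run_pipeline` body replaces the `inspect.iscoroutinefunction(method)` branch with a check on the *result*: call the method, then await only if what came back is awaitable. This treats sync methods, async methods, and wrapped callables uniformly. A standalone sketch of the pattern:

```python
# Sketch of the call pattern adopted in run_pipeline.
import asyncio
import inspect


def sync_handler() -> list[int]:
    return [1, 2]


async def async_handler() -> None:
    await asyncio.sleep(0)


async def run(method):
    result = method()
    if inspect.isawaitable(result):
        # Async methods return a coroutine here; await it to get the value.
        result = await result
    return result


async def main() -> None:
    print(await run(sync_handler))   # [1, 2]
    print(await run(async_handler))  # None


asyncio.run(main())
```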
@@ -1,14 +1,13 @@
 from abc import abstractmethod
-from typing import TYPE_CHECKING, Iterator
+from typing import TYPE_CHECKING, Awaitable, Generic, Iterator
+
+from forge.models.action import ActionResult, AnyProposal
 
 from .components import AgentComponent
 
 if TYPE_CHECKING:
     from forge.command.command import Command
     from forge.llm.providers import ChatMessage
-    from forge.models.action import ActionResult
-
-    from .base import ActionProposal
 
 
 class DirectiveProvider(AgentComponent):
@@ -34,19 +33,19 @@ class MessageProvider(AgentComponent):
     ...
 
 
-class AfterParse(AgentComponent):
+class AfterParse(AgentComponent, Generic[AnyProposal]):
     @abstractmethod
-    def after_parse(self, result: "ActionProposal") -> None:
+    def after_parse(self, result: AnyProposal) -> None | Awaitable[None]:
         ...
 
 
 class ExecutionFailure(AgentComponent):
     @abstractmethod
-    def execution_failure(self, error: Exception) -> None:
+    def execution_failure(self, error: Exception) -> None | Awaitable[None]:
         ...
 
 
 class AfterExecute(AgentComponent):
     @abstractmethod
-    def after_execute(self, result: "ActionResult") -> None:
+    def after_execute(self, result: "ActionResult") -> None | Awaitable[None]:
         ...
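With `AfterParse` now generic over the proposal type and the hooks allowed to return `Awaitable[None]`, a component can be declared against a concrete proposal type and implement its hook as a coroutine. A sketch under the assumption that `ActionProposal` (the base type the old code imported) is a valid concrete type argument; `LoggingComponent` is illustrative, not part of the codebase:

```python
from forge.agent.protocols import AfterParse
from forge.models.action import ActionProposal


class LoggingComponent(AfterParse[ActionProposal]):
    async def after_parse(self, result: ActionProposal) -> None:
        # run_pipeline detects the coroutine via inspect.isawaitable()
        # and awaits it, so async hooks need no special registration.
        print(f"parsed proposal: {result}")
```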
@ -1,39 +1,16 @@
|
|||||||
"""
|
"""
|
||||||
Routes for the Agent Service.
|
Routes for the Agent Service.
|
||||||
|
|
||||||
This module defines the API routes for the Agent service. While there are multiple endpoints provided by the service,
|
This module defines the API routes for the Agent service.
|
||||||
the ones that require special attention due to their complexity are:
|
|
||||||
|
|
||||||
1. `execute_agent_task_step`:
|
Developers and contributors should be especially careful when making modifications
|
||||||
This route is significant because this is where the agent actually performs the work. The function handles
|
to these routes to ensure consistency and correctness in the system's behavior.
|
||||||
executing the next step for a task based on its current state, and it requires careful implementation to ensure
|
|
||||||
all scenarios (like the presence or absence of steps or a step marked as `last_step`) are handled correctly.
|
|
||||||
|
|
||||||
2. `upload_agent_task_artifacts`:
|
|
||||||
This route allows for the upload of artifacts, supporting various URI types (e.g., s3, gcs, ftp, http).
|
|
||||||
The support for different URI types makes it a bit more complex, and it's important to ensure that all
|
|
||||||
supported URI types are correctly managed. NOTE: The AutoGPT team will eventually handle the most common
|
|
||||||
uri types for you.
|
|
||||||
|
|
||||||
3. `create_agent_task`:
|
|
||||||
While this is a simpler route, it plays a crucial role in the workflow, as it's responsible for the creation
|
|
||||||
of a new task.
|
|
||||||
|
|
||||||
Developers and contributors should be especially careful when making modifications to these routes to ensure
|
|
||||||
consistency and correctness in the system's behavior.
|
|
||||||
"""
|
"""
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, Request, Response, UploadFile
|
from fastapi import APIRouter, HTTPException, Query, Request, Response, UploadFile
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from forge.utils.exceptions import (
|
|
||||||
NotFoundError,
|
|
||||||
get_detailed_traceback,
|
|
||||||
get_exception_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
Artifact,
|
Artifact,
|
||||||
@ -46,6 +23,9 @@ from .models import (
|
|||||||
TaskStepsListResponse,
|
TaskStepsListResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from forge.agent.agent import Agent
|
||||||
|
|
||||||
base_router = APIRouter()
|
base_router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -73,10 +53,10 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
request (Request): FastAPI request object.
|
request (Request): FastAPI request object.
|
||||||
task (TaskRequestBody): The task request containing input and additional input data.
|
task (TaskRequestBody): The task request containing input data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Task: A new task with task_id, input, additional_input, and empty lists for artifacts and steps.
|
Task: A new task with task_id, input, and additional_input set.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
Request (TaskRequestBody defined in schema.py):
|
Request (TaskRequestBody defined in schema.py):
|
||||||
@ -93,46 +73,32 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
|
|||||||
"artifacts": [],
|
"artifacts": [],
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
agent = request["agent"]
|
agent: "Agent" = request["agent"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
task_request = await agent.create_task(task_request)
|
task = await agent.create_task(task_request)
|
||||||
return Response(
|
return task
|
||||||
content=task_request.json(),
|
|
||||||
status_code=200,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Error whilst trying to create a task: {task_request}")
|
logger.exception(f"Error whilst trying to create a task: {task_request}")
|
||||||
return Response(
|
raise
|
||||||
content=json.dumps(
|
|
||||||
{
|
|
||||||
"error": "Internal server error",
|
|
||||||
"exception": get_exception_message(),
|
|
||||||
"traceback": get_detailed_traceback(),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
status_code=500,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
|
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
|
||||||
async def list_agent_tasks(
|
async def list_agent_tasks(
|
||||||
request: Request,
|
request: Request,
|
||||||
page: Optional[int] = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
page_size: Optional[int] = Query(10, ge=1),
|
page_size: int = Query(10, ge=1),
|
||||||
) -> TaskListResponse:
|
) -> TaskListResponse:
|
||||||
"""
|
"""
|
||||||
Retrieves a paginated list of all tasks.
|
Retrieves a paginated list of all tasks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request (Request): FastAPI request object.
|
request (Request): FastAPI request object.
|
||||||
page (int, optional): The page number for pagination. Defaults to 1.
|
page (int, optional): Page number for pagination. Default: 1
|
||||||
page_size (int, optional): The number of tasks per page for pagination. Defaults to 10.
|
page_size (int, optional): Number of tasks per page for pagination. Default: 10
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TaskListResponse: A response object containing a list of tasks and pagination details.
|
TaskListResponse: A list of tasks, and pagination details.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
Request:
|
Request:
|
||||||
@ -158,34 +124,13 @@ async def list_agent_tasks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
agent = request["agent"]
|
agent: "Agent" = request["agent"]
|
||||||
try:
|
try:
|
||||||
tasks = await agent.list_tasks(page, page_size)
|
tasks = await agent.list_tasks(page, page_size)
|
||||||
return Response(
|
return tasks
|
||||||
content=tasks.json(),
|
|
||||||
status_code=200,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
logger.exception("Error whilst trying to list tasks")
|
|
||||||
return Response(
|
|
||||||
content=json.dumps({"error": "Tasks not found"}),
|
|
||||||
status_code=404,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error whilst trying to list tasks")
|
logger.exception("Error whilst trying to list tasks")
|
||||||
return Response(
|
raise
|
||||||
content=json.dumps(
|
|
||||||
{
|
|
||||||
"error": "Internal server error",
|
|
||||||
"exception": get_exception_message(),
|
|
||||||
"traceback": get_detailed_traceback(),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
status_code=500,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
|
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
|
||||||
@ -239,36 +184,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
"""
|
""" # noqa: E501
|
||||||
agent = request["agent"]
|
agent: "Agent" = request["agent"]
|
||||||
try:
|
try:
|
||||||
task = await agent.get_task(task_id)
|
task = await agent.get_task(task_id)
|
||||||
|
return task
|
||||||
return Response(
|
|
||||||
content=task.json(),
|
|
||||||
status_code=200,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
logger.exception(f"Error whilst trying to get task: {task_id}")
|
|
||||||
return Response(
|
|
||||||
content=json.dumps({"error": "Task not found"}),
|
|
||||||
status_code=404,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Error whilst trying to get task: {task_id}")
|
logger.exception(f"Error whilst trying to get task: {task_id}")
|
||||||
return Response(
|
raise
|
||||||
content=json.dumps(
|
|
||||||
{
|
|
||||||
"error": "Internal server error",
|
|
||||||
"exception": get_exception_message(),
|
|
||||||
"traceback": get_detailed_traceback(),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
status_code=500,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.get(
|
@base_router.get(
|
||||||
@ -279,8 +202,8 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
|
|||||||
async def list_agent_task_steps(
|
async def list_agent_task_steps(
|
||||||
request: Request,
|
request: Request,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
page: Optional[int] = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
|
page_size: int = Query(10, ge=1, alias="pageSize"),
|
||||||
) -> TaskStepsListResponse:
|
) -> TaskStepsListResponse:
|
||||||
"""
|
"""
|
||||||
Retrieves a paginated list of steps associated with a specific task.
|
Retrieves a paginated list of steps associated with a specific task.
|
||||||
@ -289,10 +212,10 @@ async def list_agent_task_steps(
|
|||||||
request (Request): FastAPI request object.
|
request (Request): FastAPI request object.
|
||||||
task_id (str): The ID of the task.
|
task_id (str): The ID of the task.
|
||||||
page (int, optional): The page number for pagination. Defaults to 1.
|
page (int, optional): The page number for pagination. Defaults to 1.
|
||||||
page_size (int, optional): The number of steps per page for pagination. Defaults to 10.
|
page_size (int, optional): Number of steps per page for pagination. Default: 10.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TaskStepsListResponse: A response object containing a list of steps and pagination details.
|
TaskStepsListResponse: A list of steps, and pagination details.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
Request:
|
Request:
|
||||||
@ -315,54 +238,40 @@ async def list_agent_task_steps(
|
|||||||
"pageSize": 10
|
"pageSize": 10
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
""" # noqa: E501
|
||||||
agent = request["agent"]
|
agent: "Agent" = request["agent"]
|
||||||
try:
|
try:
|
||||||
steps = await agent.list_steps(task_id, page, page_size)
|
steps = await agent.list_steps(task_id, page, page_size)
|
||||||
return Response(
|
return steps
|
||||||
content=steps.json(),
|
|
||||||
status_code=200,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
logger.exception("Error whilst trying to list steps")
|
|
||||||
return Response(
|
|
||||||
content=json.dumps({"error": "Steps not found"}),
|
|
||||||
status_code=404,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error whilst trying to list steps")
|
logger.exception("Error whilst trying to list steps")
|
||||||
return Response(
|
raise
|
||||||
content=json.dumps(
|
|
||||||
{
|
|
||||||
"error": "Internal server error",
|
|
||||||
"exception": get_exception_message(),
|
|
||||||
"traceback": get_detailed_traceback(),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
status_code=500,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
|
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
|
||||||
async def execute_agent_task_step(
|
async def execute_agent_task_step(
|
||||||
request: Request, task_id: str, step: Optional[StepRequestBody] = None
|
request: Request, task_id: str, step_request: Optional[StepRequestBody] = None
|
||||||
) -> Step:
|
) -> Step:
|
||||||
"""
|
"""
|
||||||
Executes the next step for a specified task based on the current task status and returns the
|
Executes the next step for a specified task based on the current task status and
|
||||||
executed step with additional feedback fields.
|
returns the executed step with additional feedback fields.
|
||||||
|
|
||||||
Depending on the current state of the task, the following scenarios are supported:
|
This route is significant because this is where the agent actually performs work.
|
||||||
|
The function handles executing the next step for a task based on its current state,
|
||||||
|
and it requires careful implementation to ensure all scenarios (like the presence
|
||||||
|
or absence of steps or a step marked as `last_step`) are handled correctly.
|
||||||
|
|
||||||
|
Depending on the current state of the task, the following scenarios are possible:
|
||||||
1. No steps exist for the task.
|
1. No steps exist for the task.
|
||||||
2. There is at least one step already for the task, and the task does not have a completed step marked as `last_step`.
|
2. There is at least one step already for the task, and the task does not have a
|
||||||
|
completed step marked as `last_step`.
|
||||||
3. There is a completed step marked as `last_step` already on the task.
|
3. There is a completed step marked as `last_step` already on the task.
|
||||||
|
|
||||||
In each of these scenarios, a step object will be returned with two additional fields: `output` and `additional_output`.
|
In each of these scenarios, a step object will be returned with two additional
|
||||||
|
fields: `output` and `additional_output`.
|
||||||
- `output`: Provides the primary response or feedback to the user.
|
- `output`: Provides the primary response or feedback to the user.
|
||||||
- `additional_output`: Supplementary information or data. Its specific content is not strictly defined and can vary based on the step or agent's implementation.
|
- `additional_output`: Supplementary information or data. Its specific content is
|
||||||
|
not strictly defined and can vary based on the step or agent's implementation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request (Request): FastAPI request object.
|
request (Request): FastAPI request object.
|
||||||
@ -389,39 +298,17 @@ async def execute_agent_task_step(
|
|||||||
...
|
...
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
agent = request["agent"]
|
agent: "Agent" = request["agent"]
|
||||||
try:
|
try:
|
||||||
# An empty step request represents a yes to continue command
|
# An empty step request represents a yes to continue command
|
||||||
if not step:
|
if not step_request:
|
||||||
step = StepRequestBody(input="y")
|
step_request = StepRequestBody(input="y")
|
||||||
|
|
||||||
step = await agent.execute_step(task_id, step)
|
step = await agent.execute_step(task_id, step_request)
|
||||||
|
return step
|
||||||
return Response(
|
|
||||||
content=step.json(),
|
|
||||||
status_code=200,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
|
|
||||||
return Response(
|
|
||||||
content=json.dumps({"error": f"Task not found {task_id}"}),
|
|
||||||
status_code=404,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
|
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
|
||||||
return Response(
|
raise
|
||||||
content=json.dumps(
|
|
||||||
{
|
|
||||||
"error": "Internal server error",
|
|
||||||
"exception": get_exception_message(),
|
|
||||||
"traceback": get_detailed_traceback(),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
status_code=500,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.get(
|
@base_router.get(
|
||||||
@ -450,31 +337,13 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
|
|||||||
...
|
...
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
agent = request["agent"]
|
agent: "Agent" = request["agent"]
|
||||||
try:
|
try:
|
||||||
step = await agent.get_step(task_id, step_id)
|
step = await agent.get_step(task_id, step_id)
|
||||||
|
return step
|
||||||
return Response(content=step.json(), status_code=200)
|
|
||||||
except NotFoundError:
|
|
||||||
logger.exception(f"Error whilst trying to get step: {step_id}")
|
|
||||||
return Response(
|
|
||||||
content=json.dumps({"error": "Step not found"}),
|
|
||||||
status_code=404,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Error whilst trying to get step: {step_id}")
|
logger.exception(f"Error whilst trying to get step: {step_id}")
|
||||||
return Response(
|
raise
|
||||||
content=json.dumps(
|
|
||||||
{
|
|
||||||
"error": "Internal server error",
|
|
||||||
"exception": get_exception_message(),
|
|
||||||
"traceback": get_detailed_traceback(),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
status_code=500,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.get(
|
@base_router.get(
|
||||||
@ -485,8 +354,8 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
|
|||||||
async def list_agent_task_artifacts(
|
async def list_agent_task_artifacts(
|
||||||
request: Request,
|
request: Request,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
page: Optional[int] = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
|
page_size: int = Query(10, ge=1, alias="pageSize"),
|
||||||
) -> TaskArtifactsListResponse:
|
) -> TaskArtifactsListResponse:
|
||||||
"""
|
"""
|
||||||
Retrieves a paginated list of artifacts associated with a specific task.
|
Retrieves a paginated list of artifacts associated with a specific task.
|
||||||
@ -495,10 +364,10 @@ async def list_agent_task_artifacts(
|
|||||||
request (Request): FastAPI request object.
|
request (Request): FastAPI request object.
|
||||||
task_id (str): The ID of the task.
|
task_id (str): The ID of the task.
|
||||||
page (int, optional): The page number for pagination. Defaults to 1.
|
page (int, optional): The page number for pagination. Defaults to 1.
|
||||||
page_size (int, optional): The number of items per page for pagination. Defaults to 10.
|
page_size (int, optional): Number of items per page for pagination. Default: 10.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TaskArtifactsListResponse: A response object containing a list of artifacts and pagination details.
|
TaskArtifactsListResponse: A list of artifacts, and pagination details.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
Request:
|
Request:
|
||||||
@@ -518,52 +387,33 @@ async def list_agent_task_artifacts(
             "pageSize": 10
         }
     }
-    """
-    agent = request["agent"]
+    """  # noqa: E501
+    agent: "Agent" = request["agent"]
     try:
-        artifacts: TaskArtifactsListResponse = await agent.list_artifacts(
-            task_id, page, page_size
-        )
+        artifacts = await agent.list_artifacts(task_id, page, page_size)
         return artifacts
-    except NotFoundError:
-        logger.exception("Error whilst trying to list artifacts")
-        return Response(
-            content=json.dumps({"error": "Artifacts not found for task_id"}),
-            status_code=404,
-            media_type="application/json",
-        )
     except Exception:
         logger.exception("Error whilst trying to list artifacts")
-        return Response(
-            content=json.dumps(
-                {
-                    "error": "Internal server error",
-                    "exception": get_exception_message(),
-                    "traceback": get_detailed_traceback(),
-                }
-            ),
-            status_code=500,
-            media_type="application/json",
-        )
+        raise
 @base_router.post(
     "/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact
 )
 async def upload_agent_task_artifacts(
-    request: Request, task_id: str, file: UploadFile, relative_path: Optional[str] = ""
+    request: Request, task_id: str, file: UploadFile, relative_path: str = ""
 ) -> Artifact:
     """
-    This endpoint is used to upload an artifact associated with a specific task. The artifact is provided as a file.
+    This endpoint is used to upload an artifact (file) associated with a specific task.

     Args:
         request (Request): The FastAPI request object.
-        task_id (str): The unique identifier of the task for which the artifact is being uploaded.
+        task_id (str): The ID of the task for which the artifact is being uploaded.
         file (UploadFile): The file being uploaded as an artifact.
         relative_path (str): The relative path for the file. This is a query parameter.

     Returns:
-        Artifact: An object containing metadata of the uploaded artifact, including its unique identifier.
+        Artifact: Metadata object for the uploaded artifact, including its ID and path.

     Example:
         Request:
@@ -579,35 +429,17 @@ async def upload_agent_task_artifacts(
             "relative_path": "/my_folder/my_other_folder/",
             "file_name": "main.py"
         }
-    """
-    agent = request["agent"]
+    """  # noqa: E501
+    agent: "Agent" = request["agent"]

     if file is None:
-        return Response(
-            content=json.dumps({"error": "File must be specified"}),
-            status_code=404,
-            media_type="application/json",
-        )
+        raise HTTPException(status_code=400, detail="File must be specified")
     try:
         artifact = await agent.create_artifact(task_id, file, relative_path)
-        return Response(
-            content=artifact.json(),
-            status_code=200,
-            media_type="application/json",
-        )
+        return artifact
     except Exception:
         logger.exception(f"Error whilst trying to upload artifact: {task_id}")
-        return Response(
-            content=json.dumps(
-                {
-                    "error": "Internal server error",
-                    "exception": get_exception_message(),
-                    "traceback": get_detailed_traceback(),
-                }
-            ),
-            status_code=500,
-            media_type="application/json",
-        )
+        raise


 @base_router.get(
@@ -617,7 +449,7 @@ async def upload_agent_task_artifacts(
 )
 async def download_agent_task_artifact(
     request: Request, task_id: str, artifact_id: str
-) -> FileResponse:
+) -> StreamingResponse:
     """
     Downloads an artifact associated with a specific task.

@@ -636,32 +468,9 @@ async def download_agent_task_artifact(
         Response:
             <file_content_of_artifact>
     """
-    agent = request["agent"]
+    agent: "Agent" = request["agent"]
     try:
         return await agent.get_artifact(task_id, artifact_id)
-    except NotFoundError:
-        logger.exception(f"Error whilst trying to download artifact: {task_id}")
-        return Response(
-            content=json.dumps(
-                {
-                    "error": f"Artifact not found "
-                    "- task_id: {task_id}, artifact_id: {artifact_id}"
-                }
-            ),
-            status_code=404,
-            media_type="application/json",
-        )
     except Exception:
         logger.exception(f"Error whilst trying to download artifact: {task_id}")
-        return Response(
-            content=json.dumps(
-                {
-                    "error": f"Internal server error "
-                    "- task_id: {task_id}, artifact_id: {artifact_id}",
-                    "exception": get_exception_message(),
-                    "traceback": get_detailed_traceback(),
-                }
-            ),
-            status_code=500,
-            media_type="application/json",
-        )
+        raise
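
Every handler in this router now ends in a bare `raise` instead of a hand-built JSON `Response`. That only keeps producing sensible status codes if exceptions are translated somewhere central; the sketch below illustrates that pattern with FastAPI exception handlers (hypothetical code, not part of this diff, and `NotFoundError` here is a stand-in for `forge.utils.exceptions.NotFoundError`):

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


class NotFoundError(Exception):
    """Stand-in for forge.utils.exceptions.NotFoundError."""


app = FastAPI()


@app.exception_handler(NotFoundError)
async def not_found_handler(request: Request, exc: NotFoundError) -> JSONResponse:
    # Any endpoint that raises NotFoundError now yields a uniform 404,
    # instead of each handler assembling its own error Response.
    return JSONResponse(status_code=404, content={"error": str(exc)})
```
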
@@ -1 +1,3 @@
 from .db import AgentDB
+
+__all__ = ["AgentDB"]
@@ -4,23 +4,22 @@ It uses SQLite as the database and file store backend.
 IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
 """

-import datetime
 import logging
 import math
 import uuid
+from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Tuple

-from sqlalchemy import (
-    JSON,
-    Boolean,
-    Column,
-    DateTime,
-    ForeignKey,
-    String,
-    create_engine,
-)
+from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, create_engine
 from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Mapped,
+    joinedload,
+    mapped_column,
+    relationship,
+    sessionmaker,
+)

 from forge.utils.exceptions import NotFoundError
@@ -32,18 +31,20 @@ logger = logging.getLogger(__name__)


 class Base(DeclarativeBase):
-    pass
+    type_annotation_map = {
+        dict[str, Any]: JSON,
+    }


 class TaskModel(Base):
     __tablename__ = "tasks"

-    task_id = Column(String, primary_key=True, index=True)
-    input = Column(String)
-    additional_input = Column(JSON)
-    created_at = Column(DateTime, default=datetime.datetime.utcnow)
-    modified_at = Column(
-        DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
+    task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
+    input: Mapped[str]
+    additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
+    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
+    modified_at: Mapped[datetime] = mapped_column(
+        DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
     )

     artifacts = relationship("ArtifactModel", back_populates="task")
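
This hunk and the next one migrate the models from the legacy `Column(...)` style to SQLAlchemy 2.0's typed declarative mapping, which is what lets Pyright see column types. A self-contained sketch of the mechanics, assuming SQLAlchemy >= 2.0 (`Demo` is a made-up table, not from this PR):

```python
from typing import Any, Optional

from sqlalchemy import JSON, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    # Annotations of this Python type map to JSON columns.
    type_annotation_map = {dict[str, Any]: JSON}


class Demo(Base):
    __tablename__ = "demo"

    id: Mapped[str] = mapped_column(primary_key=True)
    payload: Mapped[dict[str, Any]] = mapped_column(default=dict)  # JSON column
    note: Mapped[Optional[str]]  # a bare annotation suffices; Optional => nullable


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)  # CREATE TABLE uses the inferred column types
```
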
@@ -52,35 +53,35 @@ class TaskModel(Base):
 class StepModel(Base):
     __tablename__ = "steps"

-    step_id = Column(String, primary_key=True, index=True)
-    task_id = Column(String, ForeignKey("tasks.task_id"))
-    name = Column(String)
-    input = Column(String)
-    status = Column(String)
-    output = Column(String)
-    is_last = Column(Boolean, default=False)
-    created_at = Column(DateTime, default=datetime.datetime.utcnow)
-    modified_at = Column(
-        DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
+    step_id: Mapped[str] = mapped_column(primary_key=True, index=True)
+    task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
+    name: Mapped[str]
+    input: Mapped[str]
+    status: Mapped[str]
+    output: Mapped[Optional[str]]
+    is_last: Mapped[bool] = mapped_column(Boolean, default=False)
+    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
+    modified_at: Mapped[datetime] = mapped_column(
+        DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
     )

-    additional_input = Column(JSON)
-    additional_output = Column(JSON)
+    additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
+    additional_output: Mapped[Optional[dict[str, Any]]]
     artifacts = relationship("ArtifactModel", back_populates="step")


 class ArtifactModel(Base):
     __tablename__ = "artifacts"

-    artifact_id = Column(String, primary_key=True, index=True)
-    task_id = Column(String, ForeignKey("tasks.task_id"))
-    step_id = Column(String, ForeignKey("steps.step_id"))
-    agent_created = Column(Boolean, default=False)
-    file_name = Column(String)
-    relative_path = Column(String)
-    created_at = Column(DateTime, default=datetime.datetime.utcnow)
-    modified_at = Column(
-        DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
-    )
+    artifact_id: Mapped[str] = mapped_column(primary_key=True, index=True)
+    task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
+    step_id: Mapped[Optional[str]] = mapped_column(ForeignKey("steps.step_id"))
+    agent_created: Mapped[bool] = mapped_column(default=False)
+    file_name: Mapped[str]
+    relative_path: Mapped[str]
+    created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
+    modified_at: Mapped[datetime] = mapped_column(
+        default=datetime.utcnow, onupdate=datetime.utcnow
+    )

     step = relationship("StepModel", back_populates="artifacts")
@@ -150,6 +151,10 @@ class AgentDB:
         Base.metadata.create_all(self.engine)
         self.Session = sessionmaker(bind=self.engine)

+    def close(self) -> None:
+        self.Session.close_all()
+        self.engine.dispose()
+
     async def create_task(
         self, input: Optional[str], additional_input: Optional[dict] = {}
     ) -> Task:
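
A short usage sketch for the new `close()` method (the scratch file name is hypothetical, and the import path assumes the package layout implied by the `__init__.py` hunk above; the real consumer is the reworked fixture in db_test.py further down):

```python
import os

from forge.agent_protocol.database import AgentDB

db = AgentDB("sqlite:///scratch_db.sqlite3")
try:
    pass  # ... create and query tasks, steps, artifacts ...
finally:
    # close_all() and engine.dispose() release pooled connections, so the
    # SQLite file is no longer held open and can be removed afterwards.
    db.close()
    os.remove("scratch_db.sqlite3")
```
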
@@ -172,8 +177,6 @@ class AgentDB:
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while creating task: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while creating task: {e}")
             raise
@@ -207,8 +210,6 @@ class AgentDB:
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while creating step: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while creating step: {e}")
             raise
@@ -237,7 +238,7 @@ class AgentDB:
                 session.close()
                 if self.debug_enabled:
                     logger.debug(
-                        f"Artifact already exists with relative_path: {relative_path}"
+                        f"Artifact {file_name} already exists at {relative_path}/"
                     )
                 return convert_to_artifact(existing_artifact)
@@ -254,14 +255,12 @@ class AgentDB:
             session.refresh(new_artifact)
             if self.debug_enabled:
                 logger.debug(
-                    f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
+                    f"Created new artifact with ID: {new_artifact.artifact_id}"
                 )
             return convert_to_artifact(new_artifact)
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while creating step: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while creating step: {e}")
             raise
@@ -285,8 +284,6 @@ class AgentDB:
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while getting task: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while getting task: {e}")
             raise
@@ -312,8 +309,6 @@ class AgentDB:
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while getting step: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while getting step: {e}")
             raise
@@ -337,8 +332,6 @@ class AgentDB:
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while getting artifact: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while getting artifact: {e}")
             raise
@@ -375,14 +368,13 @@ class AgentDB:
                 return await self.get_step(task_id, step_id)
             else:
                 logger.error(
-                    f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
+                    "Can't update non-existent Step with "
+                    f"task_id: {task_id} and step_id: {step_id}"
                 )
                 raise NotFoundError("Step not found")
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while getting step: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while getting step: {e}")
             raise
@@ -441,8 +433,6 @@ class AgentDB:
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while listing tasks: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while listing tasks: {e}")
             raise
@@ -475,8 +465,6 @@ class AgentDB:
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while listing steps: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while listing steps: {e}")
             raise
@@ -509,8 +497,6 @@ class AgentDB:
         except SQLAlchemyError as e:
             logger.error(f"SQLAlchemy error while listing artifacts: {e}")
             raise
-        except NotFoundError as e:
-            raise
         except Exception as e:
             logger.error(f"Unexpected error while listing artifacts: {e}")
             raise
@@ -22,14 +22,27 @@ from forge.agent_protocol.models import (
 )
 from forge.utils.exceptions import NotFoundError as DataNotFoundError

+TEST_DB_FILENAME = "test_db.sqlite3"
+TEST_DB_URL = f"sqlite:///{TEST_DB_FILENAME}"

-@pytest.mark.asyncio
-def test_table_creation():
-    db_name = "sqlite:///test_db.sqlite3"
-    agent_db = AgentDB(db_name)

-    conn = sqlite3.connect("test_db.sqlite3")
-    cursor = conn.cursor()
+@pytest.fixture
+def agent_db():
+    db = AgentDB(TEST_DB_URL)
+    yield db
+    db.close()
+    os.remove(TEST_DB_FILENAME)
+
+
+@pytest.fixture
+def raw_db_connection(agent_db: AgentDB):
+    connection = sqlite3.connect(TEST_DB_FILENAME)
+    yield connection
+    connection.close()
+
+
+def test_table_creation(raw_db_connection: sqlite3.Connection):
+    cursor = raw_db_connection.cursor()

     # Test for tasks table existence
     cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
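
The per-test setup/teardown boilerplate is gone: `raw_db_connection` depends on `agent_db`, so the schema exists before sqlite3 connects, and pytest finalizes fixtures in reverse dependency order (the raw connection closes first, then `db.close()` and file removal run). A generic illustration of that ordering with stand-in fixtures (hypothetical names, pytest assumed):

```python
import pytest


@pytest.fixture
def database():
    db = {"open": True}  # stands in for AgentDB(TEST_DB_URL)
    yield db
    db["open"] = False  # finalized last, like db.close() + os.remove(...)


@pytest.fixture
def connection(database):
    conn = {"db": database}  # stands in for sqlite3.connect(...)
    yield conn
    assert database["open"]  # the database is still up while we tear down


def test_fixture_order(connection):
    assert connection["db"]["open"]
```
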
@@ -45,8 +58,6 @@ def test_table_creation():
     )
     assert cursor.fetchone() is not None

-    os.remove(db_name.split("///")[1])


 @pytest.mark.asyncio
 async def test_task_schema():
@@ -84,7 +95,10 @@ async def test_step_schema():
         name="Write to file",
         input="Write the words you receive to the file 'output.txt'.",
         status=StepStatus.created,
-        output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
+        output=(
+            "I am going to use the write_to_file command and write Washington "
+            "to a file called output.txt <write_to_file('output.txt', 'Washington')>"
+        ),
         artifacts=[
             Artifact(
                 artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@@ -101,13 +115,13 @@ async def test_step_schema():
     assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
     assert step.name == "Write to file"
     assert step.status == StepStatus.created
-    assert (
-        step.output
-        == "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>"
-    )
+    assert step.output == (
+        "I am going to use the write_to_file command and write Washington "
+        "to a file called output.txt <write_to_file('output.txt', 'Washington')>"
+    )
     assert len(step.artifacts) == 1
     assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
-    assert step.is_last == False
+    assert step.is_last is False


 @pytest.mark.asyncio
@@ -118,6 +132,7 @@ async def test_convert_to_task():
         created_at=now,
         modified_at=now,
         input="Write the words you receive to the file 'output.txt'.",
+        additional_input={},
         artifacts=[
             ArtifactModel(
                 artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@@ -147,6 +162,7 @@ async def test_convert_to_step():
         name="Write to file",
         status="created",
         input="Write the words you receive to the file 'output.txt'.",
+        additional_input={},
         artifacts=[
             ArtifactModel(
                 artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@@ -166,7 +182,7 @@ async def test_convert_to_step():
     assert step.status == StepStatus.created
     assert len(step.artifacts) == 1
     assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
-    assert step.is_last == False
+    assert step.is_last is False


 @pytest.mark.asyncio
@@ -183,91 +199,67 @@ async def test_convert_to_artifact():
     artifact = convert_to_artifact(artifact_model)
     assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
     assert artifact.relative_path == "file:///path/to/main.py"
-    assert artifact.agent_created == True
+    assert artifact.agent_created is True


 @pytest.mark.asyncio
-async def test_create_task():
-    # Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
-    # TODO: Fix this!
-    db_name = "sqlite:///test_db.sqlite3"
-    agent_db = AgentDB(db_name)
-
+async def test_create_task(agent_db: AgentDB):
     task = await agent_db.create_task("task_input")
     assert task.input == "task_input"
-    os.remove(db_name.split("///")[1])


 @pytest.mark.asyncio
-async def test_create_and_get_task():
-    db_name = "sqlite:///test_db.sqlite3"
-    agent_db = AgentDB(db_name)
+async def test_create_and_get_task(agent_db: AgentDB):
     task = await agent_db.create_task("test_input")
     fetched_task = await agent_db.get_task(task.task_id)
     assert fetched_task.input == "test_input"
-    os.remove(db_name.split("///")[1])


 @pytest.mark.asyncio
-async def test_get_task_not_found():
-    db_name = "sqlite:///test_db.sqlite3"
-    agent_db = AgentDB(db_name)
+async def test_get_task_not_found(agent_db: AgentDB):
     with pytest.raises(DataNotFoundError):
-        await agent_db.get_task(9999)
-    os.remove(db_name.split("///")[1])
+        await agent_db.get_task("9999")


 @pytest.mark.asyncio
-async def test_create_and_get_step():
-    db_name = "sqlite:///test_db.sqlite3"
-    agent_db = AgentDB(db_name)
+async def test_create_and_get_step(agent_db: AgentDB):
     task = await agent_db.create_task("task_input")
-    step_input = StepInput(type="python/code")
+    step_input = {"type": "python/code"}
     request = StepRequestBody(input="test_input debug", additional_input=step_input)
     step = await agent_db.create_step(task.task_id, request)
     step = await agent_db.get_step(task.task_id, step.step_id)
     assert step.input == "test_input debug"
-    os.remove(db_name.split("///")[1])


 @pytest.mark.asyncio
-async def test_updating_step():
-    db_name = "sqlite:///test_db.sqlite3"
-    agent_db = AgentDB(db_name)
+async def test_updating_step(agent_db: AgentDB):
     created_task = await agent_db.create_task("task_input")
-    step_input = StepInput(type="python/code")
+    step_input = {"type": "python/code"}
     request = StepRequestBody(input="test_input debug", additional_input=step_input)
     created_step = await agent_db.create_step(created_task.task_id, request)
     await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")

     step = await agent_db.get_step(created_task.task_id, created_step.step_id)
     assert step.status.value == "completed"
-    os.remove(db_name.split("///")[1])


 @pytest.mark.asyncio
-async def test_get_step_not_found():
-    db_name = "sqlite:///test_db.sqlite3"
-    agent_db = AgentDB(db_name)
+async def test_get_step_not_found(agent_db: AgentDB):
     with pytest.raises(DataNotFoundError):
-        await agent_db.get_step(9999, 9999)
-    os.remove(db_name.split("///")[1])
+        await agent_db.get_step("9999", "9999")


 @pytest.mark.asyncio
-async def test_get_artifact():
-    db_name = "sqlite:///test_db.sqlite3"
-    db = AgentDB(db_name)
-
+async def test_get_artifact(agent_db: AgentDB):
     # Given: A task and its corresponding artifact
-    task = await db.create_task("test_input debug")
-    step_input = StepInput(type="python/code")
+    task = await agent_db.create_task("test_input debug")
+    step_input = {"type": "python/code"}
     requst = StepRequestBody(input="test_input debug", additional_input=step_input)

-    step = await db.create_step(task.task_id, requst)
+    step = await agent_db.create_step(task.task_id, requst)

     # Create an artifact
-    artifact = await db.create_artifact(
+    artifact = await agent_db.create_artifact(
         task_id=task.task_id,
         file_name="test_get_artifact_sample_file.txt",
         relative_path="file:///path/to/test_get_artifact_sample_file.txt",
@@ -276,7 +268,7 @@ async def test_get_artifact():
     )

     # When: The artifact is fetched by its ID
-    fetched_artifact = await db.get_artifact(artifact.artifact_id)
+    fetched_artifact = await agent_db.get_artifact(artifact.artifact_id)

     # Then: The fetched artifact matches the original
     assert fetched_artifact.artifact_id == artifact.artifact_id
@@ -285,47 +277,37 @@ async def test_get_artifact():
         == "file:///path/to/test_get_artifact_sample_file.txt"
     )

-    os.remove(db_name.split("///")[1])


 @pytest.mark.asyncio
-async def test_list_tasks():
-    db_name = "sqlite:///test_db.sqlite3"
-    db = AgentDB(db_name)
-
+async def test_list_tasks(agent_db: AgentDB):
     # Given: Multiple tasks in the database
-    task1 = await db.create_task("test_input_1")
-    task2 = await db.create_task("test_input_2")
+    task1 = await agent_db.create_task("test_input_1")
+    task2 = await agent_db.create_task("test_input_2")

     # When: All tasks are fetched
-    fetched_tasks, pagination = await db.list_tasks()
+    fetched_tasks, pagination = await agent_db.list_tasks()

     # Then: The fetched tasks list includes the created tasks
     task_ids = [task.task_id for task in fetched_tasks]
     assert task1.task_id in task_ids
     assert task2.task_id in task_ids
-    os.remove(db_name.split("///")[1])


 @pytest.mark.asyncio
-async def test_list_steps():
-    db_name = "sqlite:///test_db.sqlite3"
-    db = AgentDB(db_name)
-
-    step_input = StepInput(type="python/code")
-    requst = StepRequestBody(input="test_input debug", additional_input=step_input)
+async def test_list_steps(agent_db: AgentDB):
+    step_input = {"type": "python/code"}
+    request = StepRequestBody(input="test_input debug", additional_input=step_input)

     # Given: A task and multiple steps for that task
-    task = await db.create_task("test_input")
-    step1 = await db.create_step(task.task_id, requst)
-    requst = StepRequestBody(input="step two", additional_input=step_input)
-    step2 = await db.create_step(task.task_id, requst)
+    task = await agent_db.create_task("test_input")
+    step1 = await agent_db.create_step(task.task_id, request)
+    request = StepRequestBody(input="step two")
+    step2 = await agent_db.create_step(task.task_id, request)

     # When: All steps for the task are fetched
-    fetched_steps, pagination = await db.list_steps(task.task_id)
+    fetched_steps, pagination = await agent_db.list_steps(task.task_id)

     # Then: The fetched steps list includes the created steps
     step_ids = [step.step_id for step in fetched_steps]
     assert step1.step_id in step_ids
     assert step2.step_id in step_ids
-    os.remove(db_name.split("///")[1])
@@ -1,4 +1,4 @@
-from .artifact import Artifact, ArtifactUpload
+from .artifact import Artifact
 from .pagination import Pagination
 from .task import (
     Step,
@@ -10,3 +10,16 @@ from .task import (
     TaskRequestBody,
     TaskStepsListResponse,
 )
+
+__all__ = [
+    "Artifact",
+    "Pagination",
+    "Step",
+    "StepRequestBody",
+    "StepStatus",
+    "Task",
+    "TaskArtifactsListResponse",
+    "TaskListResponse",
+    "TaskRequestBody",
+    "TaskStepsListResponse",
+]
@@ -3,15 +3,6 @@ from datetime import datetime
 from pydantic import BaseModel, Field


-class ArtifactUpload(BaseModel):
-    file: str = Field(..., description="File to upload.", format="binary")
-    relative_path: str = Field(
-        ...,
-        description="Relative path of the artifact in the agent's workspace.",
-        example="python/code",
-    )
-
-
 class Artifact(BaseModel):
     created_at: datetime = Field(
         ...,
@@ -2,7 +2,7 @@ from __future__ import annotations

 from datetime import datetime
 from enum import Enum
-from typing import List, Optional
+from typing import Any, List, Optional

 from pydantic import BaseModel, Field

@@ -17,7 +17,7 @@ class TaskRequestBody(BaseModel):
         description="Input prompt for the task.",
         example="Write the words you receive to the file 'output.txt'.",
     )
-    additional_input: Optional[dict] = None
+    additional_input: dict[str, Any] = Field(default_factory=dict)


 class Task(TaskRequestBody):
@@ -38,8 +38,8 @@ class Task(TaskRequestBody):
         description="The ID of the task.",
         example="50da533e-3904-4401-8a07-c49adf88b5eb",
     )
-    artifacts: Optional[List[Artifact]] = Field(
-        [],
+    artifacts: list[Artifact] = Field(
+        default_factory=list,
         description="A list of artifacts that the task has produced.",
         example=[
             "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
@@ -50,14 +50,12 @@ class Task(TaskRequestBody):

 class StepRequestBody(BaseModel):
     name: Optional[str] = Field(
-        None, description="The name of the task step.", example="Write to file"
+        default=None, description="The name of the task step.", example="Write to file"
     )
-    input: Optional[str] = Field(
-        None,
-        description="Input prompt for the step.",
-        example="Washington",
-    )
+    input: str = Field(
+        ..., description="Input prompt for the step.", example="Washington"
+    )
-    additional_input: Optional[dict] = None
+    additional_input: dict[str, Any] = Field(default_factory=dict)


 class StepStatus(Enum):
@@ -90,19 +88,23 @@ class Step(StepRequestBody):
         example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
     )
     name: Optional[str] = Field(
-        None, description="The name of the task step.", example="Write to file"
+        default=None, description="The name of the task step.", example="Write to file"
     )
     status: StepStatus = Field(
         ..., description="The status of the task step.", example="created"
     )
     output: Optional[str] = Field(
-        None,
+        default=None,
         description="Output of the task step.",
-        example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')",
+        example=(
+            "I am going to use the write_to_file command and write Washington "
+            "to a file called output.txt <write_to_file('output.txt', 'Washington')"
+        ),
     )
-    additional_output: Optional[dict] = None
+    additional_output: Optional[dict[str, Any]] = None
-    artifacts: Optional[List[Artifact]] = Field(
-        [], description="A list of artifacts that the step has produced."
+    artifacts: list[Artifact] = Field(
+        default_factory=list,
+        description="A list of artifacts that the step has produced.",
     )
     is_last: bool = Field(
         ..., description="Whether this is the last step in the task.", example=True
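
The recurring schema change above, `Optional[dict] = None` to `dict[str, Any] = Field(default_factory=dict)` (and likewise `Optional[List[...]] = Field([])` to `list[...] = Field(default_factory=list)`), means consumers always receive a real container and never branch on `None`. A small sketch of the effect, using pydantic v1's API as this codebase does (`StepLike` is a hypothetical model):

```python
from typing import Any

from pydantic import BaseModel, Field


class StepLike(BaseModel):
    input: str
    additional_input: dict[str, Any] = Field(default_factory=dict)


a = StepLike(input="x")
b = StepLike(input="y")
a.additional_input["k"] = 1
assert b.additional_input == {}  # each instance gets its own fresh dict
assert isinstance(b.additional_input, dict)  # no None checks needed downstream
```
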
@@ -1,3 +1,5 @@
-from .command import Command, CommandOutput, CommandParameter
+from .command import Command
 from .decorator import command
 from .parameter import CommandParameter
+
+__all__ = ["Command", "CommandParameter", "command"]
@@ -1,14 +1,16 @@
 from __future__ import annotations

 import inspect
-from typing import Any, Callable, Generic, ParamSpec, TypeVar
+from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
+
+from forge.agent.protocols import CommandProvider

 from .parameter import CommandParameter

-CommandOutput = Any
-
 P = ParamSpec("P")
-CO = TypeVar("CO", bound=CommandOutput)
+CO = TypeVar("CO")  # command output
+
+_CP = TypeVar("_CP", bound=CommandProvider)


 class Command(Generic[P, CO]):
@@ -24,7 +26,7 @@ class Command(Generic[P, CO]):
         self,
         names: list[str],
         description: str,
-        method: Callable[P, CO],
+        method: Callable[Concatenate[_CP, P], CO],
         parameters: list[CommandParameter],
     ):
         # Check if all parameters are provided
@@ -34,7 +36,9 @@ class Command(Generic[P, CO]):
         )
         self.names = names
         self.description = description
-        self.method = method
+        # Method technically has a `self` parameter, but we can ignore that
+        # since Python passes it internally.
+        self.method = cast(Callable[P, CO], method)
         self.parameters = parameters

     @property
@@ -62,7 +66,8 @@ class Command(Generic[P, CO]):
     def __str__(self) -> str:
         params = [
             f"{param.name}: "
-            + ("%s" if param.spec.required else "Optional[%s]") % param.spec.type.value
+            + ("%s" if param.spec.required else "Optional[%s]")
+            % (param.spec.type.value if param.spec.type else "Any")
             for param in self.parameters
         ]
         return (
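
The `Concatenate[_CP, P]` signature is the crux of the typing fix above: `Command` wraps unbound `CommandProvider` methods, whose first parameter is the provider instance, while callers should see a plain `Callable[P, CO]` once the method is bound. A minimal sketch of the same trick with made-up names, not from the codebase (requires Python >= 3.10):

```python
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast

P = ParamSpec("P")
R = TypeVar("R")


class Provider:
    def greet(self, name: str) -> str:
        return f"hello {name}"


_PR = TypeVar("_PR", bound=Provider)


class Wrapper(Generic[P, R]):
    def __init__(self, method: Callable[Concatenate[_PR, P], R]) -> None:
        # The provider's `self` is filled in when the method is bound, so it
        # is safe to narrow the stored callable to Callable[P, R] here.
        self.method = cast(Callable[P, R], method)


w = Wrapper(Provider.greet)  # accepted: unbound method with a leading `self`
print(Provider.greet(Provider(), "world"))  # binding supplies that first arg
```
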
@@ -1,2 +1,4 @@
 from .action_history import ActionHistoryComponent
 from .model import Episode, EpisodicActionHistory
+
+__all__ = ["ActionHistoryComponent", "Episode", "EpisodicActionHistory"]
@@ -1,27 +1,27 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional
+from typing import TYPE_CHECKING, Callable, Iterator, Optional

 from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
 from forge.llm.prompting.utils import indent
-from forge.llm.providers import ChatMessage, ChatModelProvider
+from forge.llm.providers import ChatMessage, MultiProvider

 if TYPE_CHECKING:
     from forge.config.config import Config

-from .model import AP, ActionResult, Episode, EpisodicActionHistory
+from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory


-class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[AP]):
+class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
     """Keeps track of the event history and provides a summary of the steps."""

     def __init__(
         self,
-        event_history: EpisodicActionHistory[AP],
+        event_history: EpisodicActionHistory[AnyProposal],
         max_tokens: int,
         count_tokens: Callable[[str], int],
         legacy_config: Config,
-        llm_provider: ChatModelProvider,
+        llm_provider: MultiProvider,
     ) -> None:
         self.event_history = event_history
         self.max_tokens = max_tokens
@@ -37,7 +37,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
     ):
         yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")

-    def after_parse(self, result: AP) -> None:
+    def after_parse(self, result: AnyProposal) -> None:
         self.event_history.register_action(result)

     async def after_execute(self, result: ActionResult) -> None:
@@ -48,7 +48,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[

     def _compile_progress(
         self,
-        episode_history: list[Episode],
+        episode_history: list[Episode[AnyProposal]],
         max_tokens: Optional[int] = None,
         count_tokens: Optional[Callable[[str], int]] = None,
     ) -> str:
@@ -1,25 +1,23 @@
 from __future__ import annotations

 import asyncio
-from typing import TYPE_CHECKING, Generic, Iterator, TypeVar
+from typing import TYPE_CHECKING, Generic

 from pydantic import Field
 from pydantic.generics import GenericModel

 from forge.content_processing.text import summarize_text
 from forge.llm.prompting.utils import format_numbered_list, indent
-from forge.models.action import ActionProposal, ActionResult
+from forge.models.action import ActionResult, AnyProposal
 from forge.models.utils import ModelWithSummary

 if TYPE_CHECKING:
     from forge.config.config import Config
-    from forge.llm.providers import ChatModelProvider
+    from forge.llm.providers import MultiProvider

-AP = TypeVar("AP", bound=ActionProposal)


-class Episode(GenericModel, Generic[AP]):
-    action: AP
+class Episode(GenericModel, Generic[AnyProposal]):
+    action: AnyProposal
     result: ActionResult | None
     summary: str | None = None

@@ -54,32 +52,29 @@ class Episode(GenericModel, Generic[AP]):
         return executed_action + action_result


-class EpisodicActionHistory(GenericModel, Generic[AP]):
+class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
     """Utility container for an action history"""

-    episodes: list[Episode[AP]] = Field(default_factory=list)
+    episodes: list[Episode[AnyProposal]] = Field(default_factory=list)
     cursor: int = 0
     _lock = asyncio.Lock()

     @property
-    def current_episode(self) -> Episode[AP] | None:
+    def current_episode(self) -> Episode[AnyProposal] | None:
         if self.cursor == len(self):
             return None
         return self[self.cursor]

-    def __getitem__(self, key: int) -> Episode[AP]:
+    def __getitem__(self, key: int) -> Episode[AnyProposal]:
         return self.episodes[key]

-    def __iter__(self) -> Iterator[Episode[AP]]:
-        return iter(self.episodes)

     def __len__(self) -> int:
         return len(self.episodes)

     def __bool__(self) -> bool:
         return len(self.episodes) > 0

-    def register_action(self, action: AP) -> None:
+    def register_action(self, action: AnyProposal) -> None:
         if not self.current_episode:
             self.episodes.append(Episode(action=action, result=None))
         assert self.current_episode
@@ -113,7 +108,7 @@ class EpisodicActionHistory(GenericModel, Generic[AP]):
         self.cursor = len(self.episodes)

     async def handle_compression(
-        self, llm_provider: ChatModelProvider, app_config: Config
+        self, llm_provider: MultiProvider, app_config: Config
     ) -> None:
         """Compresses each episode in the action history using an LLM.
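
With the local `AP` TypeVar removed, `AnyProposal` (defined once in `forge.models.action`) parametrizes both `Episode` and `EpisodicActionHistory`, so the two stay consistently typed. A self-contained sketch of the pattern with stand-in models (hypothetical names, pydantic v1's `GenericModel` as used here):

```python
from typing import Generic, List, Optional, TypeVar

from pydantic import BaseModel, Field
from pydantic.generics import GenericModel


class ActionProposal(BaseModel):
    thoughts: str = ""


# One shared TypeVar, reused by every generic container below.
AnyProposal = TypeVar("AnyProposal", bound=ActionProposal)


class Episode(GenericModel, Generic[AnyProposal]):
    action: AnyProposal
    summary: Optional[str] = None


class History(GenericModel, Generic[AnyProposal]):
    episodes: List[Episode[AnyProposal]] = Field(default_factory=list)


# Parametrizing once fixes the proposal type everywhere it appears:
h = History[ActionProposal]()
h.episodes.append(Episode[ActionProposal](action=ActionProposal()))
assert h.episodes[0].summary is None
```
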
@@ -3,6 +3,11 @@ from .code_executor import (
     DENYLIST_CONTROL,
     CodeExecutionError,
     CodeExecutorComponent,
-    is_docker_available,
-    we_are_running_in_a_docker_container,
 )
+
+__all__ = [
+    "ALLOWLIST_CONTROL",
+    "DENYLIST_CONTROL",
+    "CodeExecutionError",
+    "CodeExecutorComponent",
+]
Some files were not shown because too many files have changed in this diff.