Set up unified pre-commit + CI w/ linting + type checking & FIX EVERYTHING (#7171)

- **FIX ALL LINT/TYPE ERRORS IN AUTOGPT, FORGE, AND BENCHMARK**

### Linting
- Clean up linter configs for `autogpt`, `forge`, and `benchmark`
- Add type checking with Pyright
- Create unified pre-commit config
- Create unified linting and type checking CI workflow

### Testing
- Synchronize CI test setups for `autogpt`, `forge`, and `benchmark`
  - Add missing `pytest-cov` to benchmark dependencies
- Mark GCS tests as slow to speed up pre-commit test runs (see the marker snippet after this list)
- Repair `forge` test suite
  - Add `AgentDB.close()` method for test DB teardown in `db_test.py`
  - Use an actual temporary directory instead of `forge/test_workspace/`
- Move dependencies that were left behind by the `forge` code move from `autogpt` to `forge`
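
To illustrate how the slow-test exclusion fits together: the GCS storage test module carries a module-level marker (visible in the diff below), and the pre-commit pytest hooks deselect it with `-m "not slow"`:

```python
# Module-level pytest marker in the GCS storage tests: every test in the file
# is treated as "slow" and skipped by the pre-commit hooks.
import pytest

pytestmark = pytest.mark.slow
```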

### Notable type changes
- Replace uses of `ChatModelProvider` by `MultiProvider`
- Remove unnecessary exports from various `__init__.py` files
- Simplify `FileStorage.open_file` signature by removing `IOBase` from return type union
  - Implement `S3BinaryIOWrapper(BinaryIO)` type interposer for `S3FileStorage`
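
A rough sketch of the interposer idea, assuming it wraps a streamed S3 object body (attribute and parameter names here are illustrative, not the actual implementation):

```python
from typing import BinaryIO


class S3BinaryIOWrapper(BinaryIO):
    """Adapt a streamed S3 object body to the BinaryIO interface (partial sketch)."""

    def __init__(self, body) -> None:
        self._body = body  # assumed: a botocore StreamingBody

    def read(self, size: int = -1) -> bytes:
        return self._body.read(None if size < 0 else size)

    def readable(self) -> bool:
        return True

    def writable(self) -> bool:
        return False

    def close(self) -> None:
        self._body.close()
```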

- Expand overloads of `GCSFileStorage.open_file` for improved typing of read and write modes

  Had to silence type checking for the extra overloads, because (I think) Pyright is reporting a false-positive:
  https://github.com/microsoft/pyright/issues/8007
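
The overload pattern looks roughly like this; the parameter names are assumptions and only illustrate the `Literal`-typed read/write modes, not the exact `FileStorage` signature:

```python
from typing import BinaryIO, Literal, TextIO, overload


class _GCSFileStorageSketch:
    @overload
    def open_file(
        self, path: str, mode: Literal["r", "w"] = "r", binary: Literal[False] = False
    ) -> TextIO: ...

    @overload
    def open_file(
        self, path: str, mode: Literal["r", "w"], binary: Literal[True]
    ) -> BinaryIO: ...

    def open_file(self, path: str, mode: Literal["r", "w"] = "r", binary: bool = False):
        raise NotImplementedError  # sketch only; the real method opens a GCS blob
```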

- Change `count_tokens`, `get_tokenizer`, `count_message_tokens` methods on `ModelProvider`s from class methods to instance methods
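
In practice this means token counting now goes through a provider instance; a hedged before/after (constructor arguments and model name are placeholders):

```python
provider = MultiProvider()  # previously: ModelProvider.count_tokens(...) as a classmethod
n_tokens = provider.count_tokens("hello world", "gpt-4o")  # now an instance method
```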

- Move `CompletionModelFunction.schema` method -> helper function `format_function_def_for_openai` in `forge.llm.providers.openai`
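
The helper renders a function spec in the OpenAI function-calling format, approximately as follows (attribute access on `CompletionModelFunction` is assumed):

```python
from forge.llm.providers.schema import CompletionModelFunction


def format_function_def_for_openai(func: CompletionModelFunction) -> dict:
    """Sketch: turn a CompletionModelFunction into an OpenAI function definition."""
    return {
        "name": func.name,
        "description": func.description,
        "parameters": {
            "type": "object",
            "properties": {name: p.to_dict() for name, p in func.parameters.items()},
            "required": [name for name, p in func.parameters.items() if p.required],
        },
    }
```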

- Rename `ModelProvider` -> `BaseModelProvider`
- Rename `ChatModelProvider` -> `BaseChatModelProvider`
- Add type `ChatModelProvider` which is a union of all subclasses of `BaseChatModelProvider`
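
So the name `ChatModelProvider` survives, but as a plain union alias over the concrete providers rather than a base class; a sketch with assumed member names:

```python
from forge.llm.providers import AnthropicProvider, GroqProvider, OpenAIProvider  # names assumed

# Assumed members; the real alias lives in forge.llm.providers
ChatModelProvider = AnthropicProvider | GroqProvider | OpenAIProvider
```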

### Removed rather than fixed
- Remove deprecated and broken `autogpt/agbenchmark_config/benchmarks.py`
- Remove various base classes and properties on base classes in `forge.llm.providers.schema` and `forge.models.providers`

### Fixes for other issues that came to light
- Clean up `forge.agent_protocol.api_router`, `forge.agent_protocol.database`, and `forge.agent.agent`

- Add fallback behavior to `ImageGeneratorComponent`
   - Remove test for deprecated failure behavior

- Fix `agbenchmark.challenges.builtin` challenge exclusion mechanism on Windows

- Fix `_tool_calls_compat_extract_calls` in `forge.llm.providers.openai`

- Add support for `any` (= no type specified) in `JSONSchema.typescript_type`
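
Minimal illustration of the new fallback (import path and constructor usage are assumptions):

```python
from forge.json.schema import JSONSchema  # import path assumed

untyped = JSONSchema()                    # no `type` specified
assert untyped.typescript_type == "any"   # new behavior: untyped schemas map to `any`
```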

Reinier van der Leer, 2024-05-28 05:04:21 +02:00, committed by GitHub
commit f107ff8cf0 (parent 2c13a2706c)
147 changed files with 2897 additions and 2425 deletions


@ -1,4 +1,4 @@
name: AutoGPT Python CI name: AutoGPT CI
on: on:
push: push:
@ -24,57 +24,6 @@ defaults:
working-directory: autogpt working-directory: autogpt
jobs: jobs:
lint:
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles('autogpt/pyproject.toml') }}-${{ steps.get_date.outputs.date }}
- name: Install Python dependencies
run: |
curl -sSL https://install.python-poetry.org | python3 -
poetry install
- name: Lint with flake8
run: poetry run flake8
- name: Check black formatting
run: poetry run black . --check
if: success() || failure()
- name: Check isort formatting
run: poetry run isort . --check
if: success() || failure()
# - name: Check mypy formatting
# run: poetry run mypy
# if: success() || failure()
# - name: Check for unused imports and pass statements
# run: |
# cmd="autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests"
# poetry run $cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
test: test:
permissions: permissions:
contents: read contents: read


@ -1,4 +1,4 @@
name: AutoGPTs smoke test CI name: Agent smoke tests
on: on:
workflow_dispatch: workflow_dispatch:
@ -28,7 +28,7 @@ on:
- '!**/*.md' - '!**/*.md'
jobs: jobs:
run-tests: serve-agent-protocol:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:


@ -1,4 +1,4 @@
name: Benchmark CI name: AGBenchmark CI
on: on:
push: push:
@ -14,62 +14,91 @@ on:
- '!benchmark/reports/**' - '!benchmark/reports/**'
- .github/workflows/benchmark-ci.yml - .github/workflows/benchmark-ci.yml
concurrency:
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
defaults:
run:
shell: bash
env: env:
min-python-version: '3.10' min-python-version: '3.10'
jobs: jobs:
lint: test:
runs-on: ubuntu-latest permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
defaults:
run:
shell: bash
working-directory: benchmark
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
submodules: true
- name: Set up Python ${{ env.min-python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: ${{ env.min-python-version }} python-version: ${{ matrix.python-version }}
- id: get_date - name: Set up Python dependency cache
name: Get date # On Windows, unpacking cached dependencies takes longer than just installing them
working-directory: ./benchmark/ if: runner.os != 'Windows'
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('benchmark/poetry.lock') }}
- name: Install Poetry - name: Install Poetry (Unix)
working-directory: ./benchmark/ if: runner.os != 'Windows'
run: | run: |
curl -sSL https://install.python-poetry.org | python - curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies if [ "${{ runner.os }}" = "macOS" ]; then
working-directory: ./benchmark/ PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: | run: |
export POETRY_VIRTUALENVS_IN_PROJECT=true (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
poetry install -vvv
- name: Lint with flake8 $env:PATH += ";$env:APPDATA\Python\Scripts"
working-directory: ./benchmark/ echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
run: poetry run flake8
- name: Check black formatting - name: Install Python dependencies
working-directory: ./benchmark/ run: poetry install
run: poetry run black . --exclude test.py --check
if: success() || failure()
- name: Check isort formatting - name: Run pytest with coverage
working-directory: ./benchmark/
run: poetry run isort . --check
if: success() || failure()
- name: Check for unused imports and pass statements
working-directory: ./benchmark/
run: | run: |
cmd="poetry run autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring agbenchmark" poetry run pytest -vv \
$cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1) --cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
if: success() || failure() --durations=10 \
tests
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
tests-agbenchmark: - name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}
self-test-with-agent:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
@ -89,11 +118,11 @@ jobs:
python-version: ${{ env.min-python-version }} python-version: ${{ env.min-python-version }}
- name: Install Poetry - name: Install Poetry
working-directory: ./${{ matrix.agent-name }}/
run: | run: |
curl -sSL https://install.python-poetry.org | python - curl -sSL https://install.python-poetry.org | python -
- name: Run regression tests - name: Run regression tests
working-directory: .
run: | run: |
./run agent start ${{ matrix.agent-name }} ./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }} cd ${{ matrix.agent-name }}
@ -125,7 +154,6 @@ jobs:
export BUILD_SKILL_TREE=true export BUILD_SKILL_TREE=true
poetry run agbenchmark --mock poetry run agbenchmark --mock
poetry run pytest -vv -s tests
CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../frontend/assets)') || echo "No diffs" CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../frontend/assets)') || echo "No diffs"
if [ ! -z "$CHANGED" ]; then if [ ! -z "$CHANGED" ]; then

.github/workflows/forge-ci.yml (new file, 129 lines)

@ -0,0 +1,129 @@
name: Forge CI
on:
push:
branches: [ master, development, ci-test* ]
paths:
- '.github/workflows/forge-ci.yml'
- 'forge/**'
pull_request:
branches: [ master, development, release-* ]
paths:
- '.github/workflows/forge-ci.yml'
- 'forge/**'
concurrency:
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
defaults:
run:
shell: bash
working-directory: forge
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('forge/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
forge
env:
CI: true
PLAIN_OUTPUT: True
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: forge/logs/

.github/workflows/python-checks.yml (new file, 151 lines)

@ -0,0 +1,151 @@
name: Python checks
on:
push:
branches: [ master, development, ci-test* ]
paths:
- '.github/workflows/lint-ci.yml'
- 'autogpt/**'
- 'forge/**'
- 'benchmark/**'
- '**.py'
- '!autogpt/tests/vcr_cassettes'
pull_request:
branches: [ master, development, release-* ]
paths:
- '.github/workflows/lint-ci.yml'
- 'autogpt/**'
- 'forge/**'
- 'benchmark/**'
- '**.py'
- '!autogpt/tests/vcr_cassettes'
concurrency:
group: ${{ format('lint-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
defaults:
run:
shell: bash
jobs:
get-changed-parts:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- id: changes-in
name: Determine affected subprojects
uses: dorny/paths-filter@v3
with:
filters: |
autogpt:
- autogpt/autogpt/**
- autogpt/tests/**
- autogpt/poetry.lock
forge:
- forge/forge/**
- forge/tests/**
- forge/poetry.lock
benchmark:
- benchmark/agbenchmark/**
- benchmark/tests/**
- benchmark/poetry.lock
outputs:
changed-parts: ${{ steps.changes-in.outputs.changes }}
lint:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C ${{ matrix.sub-package }} install
# Lint
- name: Lint (isort)
run: poetry run isort --check .
working-directory: ${{ matrix.sub-package }}
- name: Lint (Black)
if: success() || failure()
run: poetry run black --check .
working-directory: ${{ matrix.sub-package }}
- name: Lint (Flake8)
if: success() || failure()
run: poetry run flake8 .
working-directory: ${{ matrix.sub-package }}
types:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C ${{ matrix.sub-package }} install
# Typecheck
- name: Typecheck
if: success() || failure()
run: poetry run pyright
working-directory: ${{ matrix.sub-package }}

.pre-commit-config.yaml (new file, 127 lines)

@ -0,0 +1,127 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ["--maxkb=500"]
- id: fix-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: local
# isort needs the context of which packages are installed to function, so we
# can't use a vendored isort pre-commit hook (which runs in its own isolated venv).
hooks:
- id: isort-autogpt
name: Lint (isort) - AutoGPT
entry: poetry -C autogpt run isort
files: ^autogpt/
types: [file, python]
language: system
- id: isort-forge
name: Lint (isort) - Forge
entry: poetry -C forge run isort
files: ^forge/
types: [file, python]
language: system
- id: isort-benchmark
name: Lint (isort) - Benchmark
entry: poetry -C benchmark run isort
files: ^benchmark/
types: [file, python]
language: system
- repo: https://github.com/psf/black
rev: 23.12.1
# Black has sensible defaults, doesn't need package context, and ignores
# everything in .gitignore, so it works fine without any config or arguments.
hooks:
- id: black
name: Lint (Black)
language_version: python3.10
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
# To have flake8 load the config of the individual subprojects, we have to call
# them separately.
hooks:
- id: flake8
name: Lint (Flake8) - AutoGPT
alias: flake8-autogpt
files: ^autogpt/(autogpt|scripts|tests)/
args: [--config=autogpt/.flake8]
- id: flake8
name: Lint (Flake8) - Forge
alias: flake8-forge
files: ^forge/(forge|tests)/
args: [--config=forge/.flake8]
- id: flake8
name: Lint (Flake8) - Benchmark
alias: flake8-benchmark
files: ^benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=benchmark/.flake8]
- repo: local
# To have watertight type checking, we check *all* the files in an affected
# project. To trigger on poetry.lock we also reset the file `types` filter.
hooks:
- id: pyright
name: Typecheck - AutoGPT
alias: pyright-autogpt
entry: poetry -C autogpt run pyright
args: [-p, autogpt, autogpt]
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Forge
alias: pyright-forge
entry: poetry -C forge run pyright
args: [-p, forge, forge]
files: ^forge/(forge/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Benchmark
alias: pyright-benchmark
entry: poetry -C benchmark run pyright
args: [-p, benchmark, benchmark]
files: ^benchmark/(agbenchmark|tests)/
types: [file]
language: system
pass_filenames: false
- repo: local
hooks:
- id: pytest-autogpt
name: Run tests - AutoGPT (excl. slow tests)
entry: bash -c 'cd autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(autogpt/((autogpt|tests)/|poetry\.lock$)|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
language: system
pass_filenames: false
- id: pytest-forge
name: Run tests - Forge (excl. slow tests)
entry: bash -c 'cd forge && poetry run pytest --cov=forge -m "not slow"'
files: ^forge/(forge/|tests/|poetry\.lock$)
language: system
pass_filenames: false
- id: pytest-benchmark
name: Run tests - Benchmark
entry: bash -c 'cd benchmark && poetry run pytest --cov=benchmark'
files: ^benchmark/(agbenchmark/|tests/|poetry\.lock$)
language: system
pass_filenames: false


@ -1,11 +1,14 @@
[flake8] [flake8]
max-line-length = 88 max-line-length = 88
extend-exclude = # Ignore rules that conflict with Black code style
.*_cache/, extend-ignore = E203, W503
.venv, exclude =
.git,
__pycache__/,
*.pyc,
.pytest_cache/,
venv*/,
.venv/,
data/, data/,
logs/, logs/,
tests/unit/data/, tests/unit/data/,
extend-ignore =
# No whitespace before ':' conflicts with Black style for slices
E203,


@ -1,47 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ['--maxkb=500']
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.10
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.10
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: 'v1.3.0'
# hooks:
# - id: mypy
- repo: local
hooks:
# - id: autoflake
# name: autoflake
# entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests
# language: python
# types: [ python ]
- id: pytest-check
name: pytest-check
entry: bash -c 'cd autogpt && poetry run pytest --cov=autogpt tests/unit'
language: system
pass_filenames: false
always_run: true


@ -1,77 +0,0 @@
import asyncio
import logging
import sys
from pathlib import Path
from forge.config.ai_profile import AIProfile
from forge.config.config import ConfigBuilder
from forge.file_storage import FileStorageBackendName, get_storage
from forge.logging.config import configure_logging
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.agents.agent_manager import AgentManager
from autogpt.app.main import _configure_llm_provider, run_interaction_loop
LOG_DIR = Path(__file__).parent / "logs"
def run_specific_agent(task: str, continuous_mode: bool = False) -> None:
agent = bootstrap_agent(task, continuous_mode)
asyncio.run(run_interaction_loop(agent))
def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
configure_logging(
level=logging.DEBUG,
log_dir=LOG_DIR,
plain_console_output=True,
)
config = ConfigBuilder.build_config_from_env()
config.continuous_mode = continuous_mode
config.continuous_limit = 20
config.noninteractive_mode = True
ai_profile = AIProfile(
ai_name="AutoGPT",
ai_role="a multi-purpose AI assistant.",
ai_goals=[task],
)
agent_settings = AgentSettings(
name=Agent.default_settings.name,
agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
description=Agent.default_settings.description,
ai_profile=ai_profile,
config=AgentConfiguration(
fast_llm=config.fast_llm,
smart_llm=config.smart_llm,
allow_fs_access=not config.restrict_to_workspace,
use_functions_api=config.openai_functions,
),
history=Agent.default_settings.history.copy(deep=True),
)
local = config.file_storage_backend == FileStorageBackendName.LOCAL
restrict_to_root = not local or config.restrict_to_workspace
file_storage = get_storage(
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
)
file_storage.initialize()
agent = Agent(
settings=agent_settings,
llm_provider=_configure_llm_provider(config),
file_storage=file_storage,
legacy_config=config,
)
return agent
if __name__ == "__main__":
# The first argument is the script name itself, second is the task
if len(sys.argv) != 2:
print("Usage: python script.py <task>")
sys.exit(1)
task = sys.argv[1]
run_specific_agent(task, continuous_mode=True)


@ -4,7 +4,7 @@ from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile from forge.config.ai_profile import AIProfile
from forge.config.config import Config from forge.config.config import Config
from forge.file_storage.base import FileStorage from forge.file_storage.base import FileStorage
from forge.llm.providers import ChatModelProvider from forge.llm.providers import MultiProvider
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@ -15,7 +15,7 @@ def create_agent(
ai_profile: AIProfile, ai_profile: AIProfile,
app_config: Config, app_config: Config,
file_storage: FileStorage, file_storage: FileStorage,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
directives: Optional[AIDirectives] = None, directives: Optional[AIDirectives] = None,
) -> Agent: ) -> Agent:
if not task: if not task:
@ -39,7 +39,7 @@ def configure_agent_with_state(
state: AgentSettings, state: AgentSettings,
app_config: Config, app_config: Config,
file_storage: FileStorage, file_storage: FileStorage,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
) -> Agent: ) -> Agent:
return _configure_agent( return _configure_agent(
state=state, state=state,
@ -51,7 +51,7 @@ def configure_agent_with_state(
def _configure_agent( def _configure_agent(
app_config: Config, app_config: Config,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
file_storage: FileStorage, file_storage: FileStorage,
agent_id: str = "", agent_id: str = "",
task: str = "", task: str = "",
@ -59,20 +59,22 @@ def _configure_agent(
directives: Optional[AIDirectives] = None, directives: Optional[AIDirectives] = None,
state: Optional[AgentSettings] = None, state: Optional[AgentSettings] = None,
) -> Agent: ) -> Agent:
if not (state or agent_id and task and ai_profile and directives): if state:
agent_state = state
elif agent_id and task and ai_profile and directives:
agent_state = state or create_agent_state(
agent_id=agent_id,
task=task,
ai_profile=ai_profile,
directives=directives,
app_config=app_config,
)
else:
raise TypeError( raise TypeError(
"Either (state) or (agent_id, task, ai_profile, directives)" "Either (state) or (agent_id, task, ai_profile, directives)"
" must be specified" " must be specified"
) )
agent_state = state or create_agent_state(
agent_id=agent_id,
task=task,
ai_profile=ai_profile,
directives=directives,
app_config=app_config,
)
return Agent( return Agent(
settings=agent_state, settings=agent_state,
llm_provider=llm_provider, llm_provider=llm_provider,


@ -7,7 +7,7 @@ from forge.file_storage.base import FileStorage
if TYPE_CHECKING: if TYPE_CHECKING:
from autogpt.agents.agent import Agent from autogpt.agents.agent import Agent
from forge.config.config import Config from forge.config.config import Config
from forge.llm.providers.schema import ChatModelProvider from forge.llm.providers import MultiProvider
from .configurators import _configure_agent from .configurators import _configure_agent
from .profile_generator import generate_agent_profile_for_task from .profile_generator import generate_agent_profile_for_task
@ -18,7 +18,7 @@ async def generate_agent_for_task(
task: str, task: str,
app_config: Config, app_config: Config,
file_storage: FileStorage, file_storage: FileStorage,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
) -> Agent: ) -> Agent:
ai_profile, task_directives = await generate_agent_profile_for_task( ai_profile, task_directives = await generate_agent_profile_for_task(
task=task, task=task,


@ -5,10 +5,10 @@ from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile from forge.config.ai_profile import AIProfile
from forge.config.config import Config from forge.config.config import Config
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
from forge.llm.providers import MultiProvider
from forge.llm.providers.schema import ( from forge.llm.providers.schema import (
AssistantChatMessage, AssistantChatMessage,
ChatMessage, ChatMessage,
ChatModelProvider,
CompletionModelFunction, CompletionModelFunction,
) )
from forge.models.config import SystemConfiguration, UserConfigurable from forge.models.config import SystemConfiguration, UserConfigurable
@ -141,7 +141,7 @@ class AgentProfileGeneratorConfiguration(SystemConfiguration):
required=True, required=True,
), ),
}, },
).schema ).dict()
) )
@ -160,7 +160,7 @@ class AgentProfileGenerator(PromptStrategy):
self._model_classification = model_classification self._model_classification = model_classification
self._system_prompt_message = system_prompt self._system_prompt_message = system_prompt
self._user_prompt_template = user_prompt_template self._user_prompt_template = user_prompt_template
self._create_agent_function = CompletionModelFunction.parse( self._create_agent_function = CompletionModelFunction.parse_obj(
create_agent_function create_agent_function
) )
@ -183,7 +183,7 @@ class AgentProfileGenerator(PromptStrategy):
def parse_response_content( def parse_response_content(
self, self,
response_content: AssistantChatMessage, response: AssistantChatMessage,
) -> tuple[AIProfile, AIDirectives]: ) -> tuple[AIProfile, AIDirectives]:
"""Parse the actual text response from the objective model. """Parse the actual text response from the objective model.
@ -195,15 +195,15 @@ class AgentProfileGenerator(PromptStrategy):
""" """
try: try:
if not response_content.tool_calls: if not response.tool_calls:
raise ValueError( raise ValueError(
f"LLM did not call {self._create_agent_function.name} function; " f"LLM did not call {self._create_agent_function.name} function; "
"agent profile creation failed" "agent profile creation failed"
) )
arguments: object = response_content.tool_calls[0].function.arguments arguments: object = response.tool_calls[0].function.arguments
ai_profile = AIProfile( ai_profile = AIProfile(
ai_name=arguments.get("name"), ai_name=arguments.get("name"), # type: ignore
ai_role=arguments.get("description"), ai_role=arguments.get("description"), # type: ignore
) )
ai_directives = AIDirectives( ai_directives = AIDirectives(
best_practices=arguments.get("directives", {}).get("best_practices"), best_practices=arguments.get("directives", {}).get("best_practices"),
@ -211,7 +211,7 @@ class AgentProfileGenerator(PromptStrategy):
resources=[], resources=[],
) )
except KeyError: except KeyError:
logger.debug(f"Failed to parse this response content: {response_content}") logger.debug(f"Failed to parse this response content: {response}")
raise raise
return ai_profile, ai_directives return ai_profile, ai_directives
@ -219,7 +219,7 @@ class AgentProfileGenerator(PromptStrategy):
async def generate_agent_profile_for_task( async def generate_agent_profile_for_task(
task: str, task: str,
app_config: Config, app_config: Config,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
) -> tuple[AIProfile, AIDirectives]: ) -> tuple[AIProfile, AIDirectives]:
"""Generates an AIConfig object from the given string. """Generates an AIConfig object from the given string.


@ -24,7 +24,7 @@ class MyAgent(Agent):
def __init__( def __init__(
self, self,
settings: AgentSettings, settings: AgentSettings,
llm_provider: ChatModelProvider, llm_provider: MultiProvider
file_storage: FileStorage, file_storage: FileStorage,
legacy_config: Config, legacy_config: Config,
): ):


@ -3,7 +3,7 @@ from __future__ import annotations
import inspect import inspect
import logging import logging
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Any, ClassVar, Optional
import sentry_sdk import sentry_sdk
from forge.agent.base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings from forge.agent.base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
@ -14,7 +14,7 @@ from forge.agent.protocols import (
DirectiveProvider, DirectiveProvider,
MessageProvider, MessageProvider,
) )
from forge.command.command import Command, CommandOutput from forge.command.command import Command
from forge.components.action_history import ( from forge.components.action_history import (
ActionHistoryComponent, ActionHistoryComponent,
EpisodicActionHistory, EpisodicActionHistory,
@ -34,8 +34,8 @@ from forge.llm.prompting.utils import dump_prompt
from forge.llm.providers import ( from forge.llm.providers import (
AssistantFunctionCall, AssistantFunctionCall,
ChatMessage, ChatMessage,
ChatModelProvider,
ChatModelResponse, ChatModelResponse,
MultiProvider,
) )
from forge.llm.providers.utils import function_specs_from_commands from forge.llm.providers.utils import function_specs_from_commands
from forge.models.action import ( from forge.models.action import (
@ -76,7 +76,9 @@ class AgentConfiguration(BaseAgentConfiguration):
class AgentSettings(BaseAgentSettings): class AgentSettings(BaseAgentSettings):
config: AgentConfiguration = Field(default_factory=AgentConfiguration) config: AgentConfiguration = Field( # type: ignore
default_factory=AgentConfiguration
)
history: EpisodicActionHistory[OneShotAgentActionProposal] = Field( history: EpisodicActionHistory[OneShotAgentActionProposal] = Field(
default_factory=EpisodicActionHistory[OneShotAgentActionProposal] default_factory=EpisodicActionHistory[OneShotAgentActionProposal]
@ -86,8 +88,8 @@ class AgentSettings(BaseAgentSettings):
context: AgentContext = Field(default_factory=AgentContext) context: AgentContext = Field(default_factory=AgentContext)
class Agent(BaseAgent, Configurable[AgentSettings]): class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
default_settings: AgentSettings = AgentSettings( default_settings: ClassVar[AgentSettings] = AgentSettings(
name="Agent", name="Agent",
description=__doc__ if __doc__ else "", description=__doc__ if __doc__ else "",
) )
@ -95,7 +97,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
def __init__( def __init__(
self, self,
settings: AgentSettings, settings: AgentSettings,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
file_storage: FileStorage, file_storage: FileStorage,
legacy_config: Config, legacy_config: Config,
): ):
@ -280,7 +282,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
return result return result
async def _execute_tool(self, tool_call: AssistantFunctionCall) -> CommandOutput: async def _execute_tool(self, tool_call: AssistantFunctionCall) -> Any:
"""Execute the command and return the result """Execute the command and return the result
Args: Args:


@ -43,7 +43,7 @@ class AssistantThoughts(ModelWithSummary):
class OneShotAgentActionProposal(ActionProposal): class OneShotAgentActionProposal(ActionProposal):
thoughts: AssistantThoughts thoughts: AssistantThoughts # type: ignore
class OneShotAgentPromptConfiguration(SystemConfiguration): class OneShotAgentPromptConfiguration(SystemConfiguration):
@ -186,11 +186,8 @@ class OneShotAgentPromptStrategy(PromptStrategy):
def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]: def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
response_schema = self.response_schema.copy(deep=True) response_schema = self.response_schema.copy(deep=True)
if ( assert response_schema.properties
use_functions_api if use_functions_api and "use_tool" in response_schema.properties:
and response_schema.properties
and "use_tool" in response_schema.properties
):
del response_schema.properties["use_tool"] del response_schema.properties["use_tool"]
# Unindent for performance # Unindent for performance
@ -288,10 +285,10 @@ class OneShotAgentPromptStrategy(PromptStrategy):
"Parsing object extracted from LLM response:\n" "Parsing object extracted from LLM response:\n"
f"{json.dumps(assistant_reply_dict, indent=4)}" f"{json.dumps(assistant_reply_dict, indent=4)}"
) )
parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
if self.config.use_functions_api: if self.config.use_functions_api:
if not response.tool_calls: if not response.tool_calls:
raise InvalidAgentResponseError("Assistant did not use a tool") raise InvalidAgentResponseError("Assistant did not use a tool")
parsed_response.use_tool = response.tool_calls[0].function assistant_reply_dict["use_tool"] = response.tool_calls[0].function
parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
return parsed_response return parsed_response


@ -25,7 +25,7 @@ from forge.agent_protocol.models import (
) )
from forge.config.config import Config from forge.config.config import Config
from forge.file_storage import FileStorage from forge.file_storage import FileStorage
from forge.llm.providers import ChatModelProvider, ModelProviderBudget from forge.llm.providers import ModelProviderBudget, MultiProvider
from forge.models.action import ActionErrorResult, ActionSuccessResult from forge.models.action import ActionErrorResult, ActionSuccessResult
from forge.utils.const import ASK_COMMAND, FINISH_COMMAND from forge.utils.const import ASK_COMMAND, FINISH_COMMAND
from forge.utils.exceptions import AgentFinished, NotFoundError from forge.utils.exceptions import AgentFinished, NotFoundError
@ -49,7 +49,7 @@ class AgentProtocolServer:
app_config: Config, app_config: Config,
database: AgentDB, database: AgentDB,
file_storage: FileStorage, file_storage: FileStorage,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
): ):
self.app_config = app_config self.app_config = app_config
self.db = database self.db = database
@ -444,9 +444,7 @@ class AgentProtocolServer:
agent_id = task_agent_id(task_id) agent_id = task_agent_id(task_id)
return self.file_storage.clone_with_subroot(f"agents/{agent_id}/workspace") return self.file_storage.clone_with_subroot(f"agents/{agent_id}/workspace")
def _get_task_llm_provider( def _get_task_llm_provider(self, task: Task, step_id: str = "") -> MultiProvider:
self, task: Task, step_id: str = ""
) -> ChatModelProvider:
""" """
Configures the LLM provider with headers to link outgoing requests to the task. Configures the LLM provider with headers to link outgoing requests to the task.
""" """


@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional
from colorama import Fore, Style from colorama import Fore, Style
from forge.agent_protocol.database import AgentDB from forge.agent_protocol.database import AgentDB
from forge.components.code_executor import ( from forge.components.code_executor.code_executor import (
is_docker_available, is_docker_available,
we_are_running_in_a_docker_container, we_are_running_in_a_docker_container,
) )
@ -82,7 +82,9 @@ async def run_auto_gpt(
local = config.file_storage_backend == FileStorageBackendName.LOCAL local = config.file_storage_backend == FileStorageBackendName.LOCAL
restrict_to_root = not local or config.restrict_to_workspace restrict_to_root = not local or config.restrict_to_workspace
file_storage = get_storage( file_storage = get_storage(
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root config.file_storage_backend,
root_path=Path("data"),
restrict_to_root=restrict_to_root,
) )
file_storage.initialize() file_storage.initialize()
@ -353,7 +355,9 @@ async def run_auto_gpt_server(
local = config.file_storage_backend == FileStorageBackendName.LOCAL local = config.file_storage_backend == FileStorageBackendName.LOCAL
restrict_to_root = not local or config.restrict_to_workspace restrict_to_root = not local or config.restrict_to_workspace
file_storage = get_storage( file_storage = get_storage(
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root config.file_storage_backend,
root_path=Path("data"),
restrict_to_root=restrict_to_root,
) )
file_storage.initialize() file_storage.initialize()


@ -7,7 +7,7 @@ import re
import socket import socket
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Coroutine, ParamSpec, TypeVar from typing import Any, Callable, Coroutine, ParamSpec, TypeVar, cast
import requests import requests
from colorama import Fore, Style from colorama import Fore, Style
@ -88,7 +88,7 @@ def vcs_state_diverges_from_master() -> bool:
def get_git_user_email() -> str: def get_git_user_email() -> str:
try: try:
repo = Repo(search_parent_directories=True) repo = Repo(search_parent_directories=True)
return repo.config_reader().get_value("user", "email", default="") return cast(str, repo.config_reader().get_value("user", "email", default=""))
except InvalidGitRepositoryError: except InvalidGitRepositoryError:
return "" return ""

autogpt/poetry.lock (generated, 529 changed lines; diff suppressed because one or more lines are too long)


@ -1,9 +1,7 @@
[tool.poetry] [tool.poetry]
name = "agpt" name = "agpt"
version = "0.5.0" version = "0.5.0"
authors = [ authors = ["Significant Gravitas <support@agpt.co>"]
"Significant Gravitas <support@agpt.co>",
]
readme = "README.md" readme = "README.md"
description = "An open-source attempt to make GPT-4 autonomous" description = "An open-source attempt to make GPT-4 autonomous"
homepage = "https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt" homepage = "https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt"
@ -30,11 +28,10 @@ charset-normalizer = "^3.1.0"
click = "*" click = "*"
colorama = "^0.4.6" colorama = "^0.4.6"
distro = "^1.8.0" distro = "^1.8.0"
en-core-web-sm = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl"} en-core-web-sm = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl" }
fastapi = "^0.109.1" fastapi = "^0.109.1"
ftfy = "^6.1.1" ftfy = "^6.1.1"
google-api-python-client = "*" google-api-python-client = "*"
gTTS = "^2.3.1"
hypercorn = "^0.14.4" hypercorn = "^0.14.4"
inflection = "*" inflection = "*"
jsonschema = "*" jsonschema = "*"
@ -58,21 +55,18 @@ openapi-python-client = "^0.14.0"
# Benchmarking # Benchmarking
agbenchmark = { path = "../benchmark", optional = true } agbenchmark = { path = "../benchmark", optional = true }
# agbenchmark = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "benchmark", optional = true} # agbenchmark = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "benchmark", optional = true}
google-cloud-logging = "^3.8.0"
google-cloud-storage = "^2.13.0"
psycopg2-binary = "^2.9.9" psycopg2-binary = "^2.9.9"
[tool.poetry.extras] [tool.poetry.extras]
benchmark = ["agbenchmark"] benchmark = ["agbenchmark"]
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
black = "*" black = "^23.12.1"
boto3-stubs = {extras = ["s3"], version = "^1.33.6"} flake8 = "^7.0.0"
flake8 = "*"
gitpython = "^3.1.32" gitpython = "^3.1.32"
isort = "*" isort = "^5.13.1"
mypy = "*"
pre-commit = "*" pre-commit = "*"
pyright = "^1.1.364"
types-beautifulsoup4 = "*" types-beautifulsoup4 = "*"
types-colorama = "*" types-colorama = "*"
types-Markdown = "*" types-Markdown = "*"
@ -89,7 +83,7 @@ pytest-integration = "*"
pytest-mock = "*" pytest-mock = "*"
pytest-recording = "*" pytest-recording = "*"
pytest-xdist = "*" pytest-xdist = "*"
vcrpy = {git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master"} vcrpy = { git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master" }
[build-system] [build-system]
@ -101,50 +95,18 @@ build-backend = "poetry.core.masonry.api"
line-length = 88 line-length = 88
target-version = ['py310'] target-version = ['py310']
include = '\.pyi?$' include = '\.pyi?$'
packages = ["autogpt"]
extend-exclude = '.+/(dist|.venv|venv|build|data)/.+'
[tool.isort] [tool.isort]
profile = "black" profile = "black"
multi_line_output = 3 skip_glob = ["data"]
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
sections = [
"FUTURE",
"STDLIB",
"THIRDPARTY",
"FIRSTPARTY",
"LOCALFOLDER"
]
extend_skip = [
"agbenchmark_config/temp_folder/",
"data/",
]
[tool.mypy] [tool.pyright]
follow_imports = 'skip' pythonVersion = "3.10"
check_untyped_defs = true exclude = ["data/**", "**/node_modules", "**/__pycache__", "**/.*"]
disallow_untyped_calls = true ignore = ["../forge/**"]
files = [
'autogpt/**/*.py',
'tests/**/*.py'
]
[[tool.mypy.overrides]]
module = [
'requests.*',
'yaml.*'
]
ignore_missing_imports = true
[tool.pytest.ini_options] [tool.pytest.ini_options]
markers = [ markers = ["slow", "requires_openai_api_key", "requires_huggingface_api_key"]
"requires_openai_api_key",
"requires_huggingface_api_key"
]


@ -4,12 +4,12 @@ import sys
from importlib.metadata import version from importlib.metadata import version
try: try:
import poetry.factory # noqa import poetry.factory # type: ignore # noqa
except ModuleNotFoundError: except ModuleNotFoundError:
os.system(f"{sys.executable} -m pip install 'poetry>=1.6.1,<2.0.0'") os.system(f"{sys.executable} -m pip install 'poetry>=1.6.1,<2.0.0'")
from poetry.core.constraints.version.version import Version from poetry.core.constraints.version.version import Version # type: ignore
from poetry.factory import Factory from poetry.factory import Factory # type: ignore
def main(): def main():


@ -20,7 +20,7 @@ from autogpt.app.utils import coroutine
) )
@coroutine @coroutine
async def generate_release_notes(repo_path: Optional[Path] = None): async def generate_release_notes(repo_path: Optional[Path] = None):
logger = logging.getLogger(generate_release_notes.name) logger = logging.getLogger(generate_release_notes.name) # pyright: ignore
repo = Repo(repo_path, search_parent_directories=True) repo = Repo(repo_path, search_parent_directories=True)
tags = list(repo.tags) tags = list(repo.tags)


@ -12,7 +12,7 @@ from forge.file_storage.local import (
FileStorageConfiguration, FileStorageConfiguration,
LocalFileStorage, LocalFileStorage,
) )
from forge.llm.providers import ChatModelProvider from forge.llm.providers import MultiProvider
from forge.logging.config import configure_logging from forge.logging.config import configure_logging
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@ -71,14 +71,12 @@ def setup_logger(config: Config):
@pytest.fixture @pytest.fixture
def llm_provider(config: Config) -> ChatModelProvider: def llm_provider(config: Config) -> MultiProvider:
return _configure_llm_provider(config) return _configure_llm_provider(config)
@pytest.fixture @pytest.fixture
def agent( def agent(config: Config, llm_provider: MultiProvider, storage: FileStorage) -> Agent:
config: Config, llm_provider: ChatModelProvider, storage: FileStorage
) -> Agent:
ai_profile = AIProfile( ai_profile = AIProfile(
ai_name="Base", ai_name="Base",
ai_role="A base AI", ai_role="A base AI",


@ -1,13 +1,16 @@
from pathlib import Path
import pytest import pytest
from forge.config.ai_profile import AIProfile from forge.config.ai_profile import AIProfile
from forge.config.config import Config from forge.config.config import Config
from forge.file_storage import FileStorageBackendName, get_storage from forge.file_storage import FileStorageBackendName, get_storage
from forge.llm.providers import MultiProvider
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@pytest.fixture @pytest.fixture
def dummy_agent(config: Config, llm_provider, memory_json_file): def dummy_agent(config: Config, llm_provider: MultiProvider):
ai_profile = AIProfile( ai_profile = AIProfile(
ai_name="Dummy Agent", ai_name="Dummy Agent",
ai_role="Dummy Role", ai_role="Dummy Role",
@ -31,7 +34,9 @@ def dummy_agent(config: Config, llm_provider, memory_json_file):
local = config.file_storage_backend == FileStorageBackendName.LOCAL local = config.file_storage_backend == FileStorageBackendName.LOCAL
restrict_to_root = not local or config.restrict_to_workspace restrict_to_root = not local or config.restrict_to_workspace
file_storage = get_storage( file_storage = get_storage(
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root config.file_storage_backend,
root_path=Path("data"),
restrict_to_root=restrict_to_root,
) )
file_storage.initialize() file_storage.initialize()


@ -4,7 +4,7 @@ import tempfile
from pathlib import Path from pathlib import Path
import pytest import pytest
from forge.components.code_executor import ( from forge.components.code_executor.code_executor import (
ALLOWLIST_CONTROL, ALLOWLIST_CONTROL,
CodeExecutorComponent, CodeExecutorComponent,
is_docker_available, is_docker_available,


@ -257,17 +257,3 @@ def test_huggingface_fail_request_bad_image(
result = image_gen_component.generate_image("astronaut riding a horse", 512) result = image_gen_component.generate_image("astronaut riding a horse", 512)
assert result == "Error creating image." assert result == "Error creating image."
def test_huggingface_fail_missing_api_token(
mocker, image_gen_component: ImageGeneratorComponent, agent: Agent
):
agent.legacy_config.image_provider = "huggingface"
agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
# Mock requests.post to raise ValueError
mocker.patch("requests.post", side_effect=ValueError)
# Verify request raises an error.
with pytest.raises(ValueError):
image_gen_component.generate_image("astronaut riding a horse", 512)


@ -67,8 +67,8 @@ def test_missing_azure_config(config: Config) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
config.openai_credentials.load_azure_config(config_file) config.openai_credentials.load_azure_config(config_file)
assert config.openai_credentials.api_type != "azure" assert config.openai_credentials.api_type != SecretStr("azure")
assert config.openai_credentials.api_version == "" assert config.openai_credentials.api_version is None
assert config.openai_credentials.azure_model_to_deploy_id_map is None assert config.openai_credentials.azure_model_to_deploy_id_map is None
@ -98,8 +98,8 @@ azure_model_map:
def test_azure_config(config_with_azure: Config) -> None: def test_azure_config(config_with_azure: Config) -> None:
assert (credentials := config_with_azure.openai_credentials) is not None assert (credentials := config_with_azure.openai_credentials) is not None
assert credentials.api_type == "azure" assert credentials.api_type == SecretStr("azure")
assert credentials.api_version == "2023-06-01-preview" assert credentials.api_version == SecretStr("2023-06-01-preview")
assert credentials.azure_endpoint == SecretStr("https://dummy.openai.azure.com") assert credentials.azure_endpoint == SecretStr("https://dummy.openai.azure.com")
assert credentials.azure_model_to_deploy_id_map == { assert credentials.azure_model_to_deploy_id_map == {
config_with_azure.fast_llm: "FAST-LLM_ID", config_with_azure.fast_llm: "FAST-LLM_ID",


@ -4,7 +4,7 @@ from pathlib import Path
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from forge.file_storage import GCSFileStorage, GCSFileStorageConfiguration from forge.file_storage.gcs import GCSFileStorage, GCSFileStorageConfiguration
from google.auth.exceptions import GoogleAuthError from google.auth.exceptions import GoogleAuthError
from google.cloud import storage from google.cloud import storage
from google.cloud.exceptions import NotFound from google.cloud.exceptions import NotFound
@ -14,6 +14,8 @@ try:
except GoogleAuthError: except GoogleAuthError:
pytest.skip("Google Cloud Authentication not configured", allow_module_level=True) pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def gcs_bucket_name() -> str: def gcs_bucket_name() -> str:
@ -26,7 +28,7 @@ def gcs_root() -> Path:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path) -> GCSFileStorage: def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path):
os.environ["STORAGE_BUCKET"] = gcs_bucket_name os.environ["STORAGE_BUCKET"] = gcs_bucket_name
storage_config = GCSFileStorageConfiguration.from_env() storage_config = GCSFileStorageConfiguration.from_env()
storage_config.root = gcs_root storage_config.root = gcs_root
@ -52,7 +54,7 @@ def test_initialize(gcs_bucket_name: str, gcs_storage_uninitialized: GCSFileStor
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def gcs_storage(gcs_storage_uninitialized: GCSFileStorage) -> GCSFileStorage: def gcs_storage(gcs_storage_uninitialized: GCSFileStorage):
(gcs_storage := gcs_storage_uninitialized).initialize() (gcs_storage := gcs_storage_uninitialized).initialize()
yield gcs_storage # type: ignore yield gcs_storage # type: ignore
@ -77,7 +79,7 @@ TEST_FILES: list[tuple[str | Path, str]] = [
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def gcs_storage_with_files(gcs_storage: GCSFileStorage) -> GCSFileStorage: async def gcs_storage_with_files(gcs_storage: GCSFileStorage):
for file_name, file_content in TEST_FILES: for file_name, file_content in TEST_FILES:
gcs_storage._bucket.blob( gcs_storage._bucket.blob(
str(gcs_storage.get_path(file_name)) str(gcs_storage.get_path(file_name))


@ -1,7 +1,7 @@
import json import json
import pytest import pytest
from forge.json import json_loads from forge.json.parsing import json_loads
_JSON_FIXABLE: list[tuple[str, str]] = [ _JSON_FIXABLE: list[tuple[str, str]] = [
# Missing comma # Missing comma


@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
import pytest import pytest
from forge.file_storage import FileStorageConfiguration, LocalFileStorage from forge.file_storage.local import FileStorageConfiguration, LocalFileStorage
_ACCESSIBLE_PATHS = [ _ACCESSIBLE_PATHS = [
Path("."), Path("."),


@ -5,7 +5,7 @@ from pathlib import Path
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from forge.file_storage import S3FileStorage, S3FileStorageConfiguration from forge.file_storage.s3 import S3FileStorage, S3FileStorageConfiguration
if not (os.getenv("S3_ENDPOINT_URL") and os.getenv("AWS_ACCESS_KEY_ID")): if not (os.getenv("S3_ENDPOINT_URL") and os.getenv("AWS_ACCESS_KEY_ID")):
pytest.skip("S3 environment variables are not set", allow_module_level=True) pytest.skip("S3 environment variables are not set", allow_module_level=True)
@ -22,7 +22,7 @@ def s3_root() -> Path:
@pytest.fixture @pytest.fixture
def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path) -> S3FileStorage: def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path):
os.environ["STORAGE_BUCKET"] = s3_bucket_name os.environ["STORAGE_BUCKET"] = s3_bucket_name
storage_config = S3FileStorageConfiguration.from_env() storage_config = S3FileStorageConfiguration.from_env()
storage_config.root = s3_root storage_config.root = s3_root
@ -36,12 +36,13 @@ def test_initialize(s3_bucket_name: str, s3_storage_uninitialized: S3FileStorage
# test that the bucket doesn't exist yet # test that the bucket doesn't exist yet
with pytest.raises(ClientError): with pytest.raises(ClientError):
s3.meta.client.head_bucket(Bucket=s3_bucket_name) s3.meta.client.head_bucket(Bucket=s3_bucket_name) # pyright: ignore
s3_storage_uninitialized.initialize() s3_storage_uninitialized.initialize()
# test that the bucket has been created # test that the bucket has been created
s3.meta.client.head_bucket(Bucket=s3_bucket_name) s3.meta.client.head_bucket(Bucket=s3_bucket_name) # pyright: ignore
# FIXME: remove the "pyright: ignore" comments after moving this test file to forge
def test_workspace_bucket_name( def test_workspace_bucket_name(
@ -52,7 +53,7 @@ def test_workspace_bucket_name(
@pytest.fixture @pytest.fixture
def s3_storage(s3_storage_uninitialized: S3FileStorage) -> S3FileStorage: def s3_storage(s3_storage_uninitialized: S3FileStorage):
(s3_storage := s3_storage_uninitialized).initialize() (s3_storage := s3_storage_uninitialized).initialize()
yield s3_storage # type: ignore yield s3_storage # type: ignore
@ -71,7 +72,7 @@ TEST_FILES: list[tuple[str | Path, str]] = [
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def s3_storage_with_files(s3_storage: S3FileStorage) -> S3FileStorage: async def s3_storage_with_files(s3_storage: S3FileStorage):
for file_name, file_content in TEST_FILES: for file_name, file_content in TEST_FILES:
s3_storage._bucket.Object(str(s3_storage.get_path(file_name))).put( s3_storage._bucket.Object(str(s3_storage.get_path(file_name))).put(
Body=file_content Body=file_content


@ -1,6 +1,7 @@
import logging import logging
import os import os
from hashlib import sha256 from hashlib import sha256
from typing import cast
import pytest import pytest
from openai import OpenAI from openai import OpenAI
@ -53,11 +54,14 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
def _patched_prepare_options(self, options: FinalRequestOptions): def _patched_prepare_options(self, options: FinalRequestOptions):
_prepare_options(options) _prepare_options(options)
if not options.json_data:
return
headers: dict[str, str | Omit] = ( headers: dict[str, str | Omit] = (
{**options.headers} if is_given(options.headers) else {} {**options.headers} if is_given(options.headers) else {}
) )
options.headers = headers options.headers = headers
data: dict = options.json_data data = cast(dict, options.json_data)
logging.getLogger("cached_openai_client").debug( logging.getLogger("cached_openai_client").debug(
f"Outgoing API request: {headers}\n{data if data else None}" f"Outgoing API request: {headers}\n{data if data else None}"
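The patched `_prepare_options` above now bails out early when there is no JSON body and uses `typing.cast` instead of a mismatched `data: dict` annotation. A small sketch of that narrowing pattern, assuming a request-options object whose payload is loosely typed; `Options` and `log_payload` are illustrative names, not the real OpenAI client classes:

```python
from typing import Any, cast


class Options:  # stand-in for a request-options object with a loosely typed payload
    def __init__(self, json_data: object | None = None) -> None:
        self.json_data = json_data


def log_payload(options: Options) -> None:
    if not options.json_data:
        return  # nothing to log for requests without a JSON body
    data = cast(dict[str, Any], options.json_data)  # typing-only, no runtime check
    print(f"Outgoing request payload keys: {list(data)}")


log_payload(Options({"model": "example", "messages": []}))
```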


@ -1,15 +1,12 @@
[flake8] [flake8]
max-line-length = 88 max-line-length = 88
select = "E303, W293, W291, W292, E305, E231, E302" # Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude = exclude =
.tox, __pycache__/,
__pycache__,
*.pyc, *.pyc,
.env .pytest_cache/,
venv*/*, venv*/,
.venv/*, .venv/,
reports/*, reports/,
dist/*, agbenchmark/reports/,
agent/*,
code,
agbenchmark/challenges/*


@ -1,36 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ['--maxkb=500']
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.10
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.10
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.3.0'
hooks:
- id: mypy
- repo: local
hooks:
- id: autoflake
name: autoflake
entry: autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring --in-place agbenchmark
language: python
types: [ python ]


@ -28,7 +28,7 @@ async def run_api_agent(
configuration = Configuration(host=config.host) configuration = Configuration(host=config.host)
async with ApiClient(configuration) as api_client: async with ApiClient(configuration) as api_client:
api_instance = AgentApi(api_client) api_instance = AgentApi(api_client)
task_request_body = TaskRequestBody(input=task) task_request_body = TaskRequestBody(input=task, additional_input=None)
start_time = time.time() start_time = time.time()
response = await api_instance.create_agent_task( response = await api_instance.create_agent_task(


@ -106,8 +106,8 @@ def find_agbenchmark_without_uvicorn():
class CreateReportRequest(BaseModel): class CreateReportRequest(BaseModel):
test: str = None test: str
test_run_id: str = None test_run_id: str
# category: Optional[str] = [] # category: Optional[str] = []
mock: Optional[bool] = False mock: Optional[bool] = False
@ -178,8 +178,8 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
logger.debug(f"Benchmark finished running in {time.time() - start_time} s") logger.debug(f"Benchmark finished running in {time.time() - start_time} s")
# List all folders in the current working directory # List all folders in the current working directory
path_reports = agbenchmark_config.reports_folder reports_folder = agbenchmark_config.reports_folder
folders = [folder for folder in path_reports.iterdir() if folder.is_dir()] folders = [folder for folder in reports_folder.iterdir() if folder.is_dir()]
# Sort the folders based on their names # Sort the folders based on their names
sorted_folders = sorted(folders, key=lambda x: x.name) sorted_folders = sorted(folders, key=lambda x: x.name)
@ -196,13 +196,14 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
data = json.load(file) data = json.load(file)
logger.debug(f"Report data: {data}") logger.debug(f"Report data: {data}")
else: else:
logger.error( raise HTTPException(
502,
"Could not get result after running benchmark: " "Could not get result after running benchmark: "
f"'report.json' does not exist in '{latest_folder}'" f"'report.json' does not exist in '{latest_folder}'",
) )
else: else:
logger.error( raise HTTPException(
"Could not get result after running benchmark: no reports found" 504, "Could not get result after running benchmark: no reports found"
) )
return data return data
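The hunk above replaces `logger.error(...)` calls with raised `HTTPException`s, so a missing report yields a proper 5xx response instead of falling through to an unbound `data`. A hedged sketch of the same FastAPI pattern; the route path and message are invented for illustration:

```python
from fastapi import FastAPI, HTTPException

app = FastAPI()


@app.get("/reports/latest")
def latest_report() -> dict:
    report: dict | None = None  # stand-in for "load report.json if it exists"
    if report is None:
        # status code first, detail second, as in the change above
        raise HTTPException(502, "Could not get result: 'report.json' was not found")
    return report
```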
@ -239,7 +240,9 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
api_instance = AgentApi(api_client) api_instance = AgentApi(api_client)
task_input = challenge_info.task task_input = challenge_info.task
task_request_body = TaskRequestBody(input=task_input) task_request_body = TaskRequestBody(
input=task_input, additional_input=None
)
task_response = await api_instance.create_agent_task( task_response = await api_instance.create_agent_task(
task_request_body=task_request_body task_request_body=task_request_body
) )
@ -276,7 +279,7 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
# Forward the request # Forward the request
response = await client.post( response = await client.post(
new_url, new_url,
data=await request.body(), content=await request.body(),
headers=dict(request.headers), headers=dict(request.headers),
) )


@ -1,7 +1,7 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import AsyncIterator, ClassVar, Optional from typing import AsyncIterator, Awaitable, ClassVar, Optional
import pytest import pytest
from agent_protocol_client import AgentApi, Step from agent_protocol_client import AgentApi, Step
@ -54,7 +54,7 @@ class BaseChallenge(ABC):
config: AgentBenchmarkConfig, config: AgentBenchmarkConfig,
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
i_attempt: int, i_attempt: int,
) -> None: ) -> None | Awaitable[None]:
""" """
Test method for use by Pytest-based benchmark sessions. Should return normally Test method for use by Pytest-based benchmark sessions. Should return normally
if the challenge passes, and raise a (preferably descriptive) error otherwise. if the challenge passes, and raise a (preferably descriptive) error otherwise.


@ -1,4 +1,3 @@
from collections import deque
import glob import glob
import json import json
import logging import logging
@ -6,19 +5,17 @@ import os
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
from collections import deque
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar, Iterator, Literal, Optional from typing import Annotated, Any, ClassVar, Iterator, Literal, Optional
import pytest import pytest
from agent_protocol_client import ( from agent_protocol_client import AgentApi, ApiClient
AgentApi, from agent_protocol_client import Configuration as ClientConfig
ApiClient, from agent_protocol_client import Step
Configuration as ClientConfig,
Step,
)
from colorama import Fore, Style from colorama import Fore, Style
from openai import _load_client as get_openai_client from openai import _load_client as get_openai_client
from pydantic import BaseModel, constr, Field, validator from pydantic import BaseModel, Field, constr, validator
from agbenchmark.agent_api_interface import download_agent_artifacts_into_folder from agbenchmark.agent_api_interface import download_agent_artifacts_into_folder
from agbenchmark.agent_interface import copy_challenge_artifacts_into_workspace from agbenchmark.agent_interface import copy_challenge_artifacts_into_workspace
@ -49,7 +46,7 @@ class BuiltinChallengeSpec(BaseModel):
class Info(BaseModel): class Info(BaseModel):
difficulty: DifficultyLevel difficulty: DifficultyLevel
description: constr(regex=r"^Tests if the agent can.*") description: Annotated[str, constr(regex=r"^Tests if the agent can.*")]
side_effects: list[str] = Field(default_factory=list) side_effects: list[str] = Field(default_factory=list)
info: Info info: Info
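This hunk (and the matching ones in the report models further down) wraps `constr(...)` in `Annotated[str, ...]`. A minimal sketch of the pattern, assuming pydantic v1 with its `regex` keyword as used throughout this diff; the model name is illustrative. `constr()` builds a type dynamically, which static checkers reject as an annotation, while the `Annotated` form lets them see a plain `str` and keeps the constraint declaration alongside it:

```python
from typing import Annotated

from pydantic import BaseModel, constr


class ChallengeInfo(BaseModel):
    # before: description: constr(regex=r"^Tests if the agent can.*")
    description: Annotated[str, constr(regex=r"^Tests if the agent can.*")]


info = ChallengeInfo(description="Tests if the agent can read a file")
print(info.description)
```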
@ -184,7 +181,7 @@ class BuiltinChallenge(BaseChallenge):
steps: list[Step] = [] steps: list[Step] = []
try: try:
async for step in self.run_challenge( async for step in self.run_challenge(
config, timeout, mock=request.config.getoption("--mock") config, timeout, mock=bool(request.config.getoption("--mock"))
): ):
if not task_id: if not task_id:
task_id = step.task_id task_id = step.task_id
@ -199,6 +196,8 @@ class BuiltinChallenge(BaseChallenge):
timed_out = False timed_out = False
except TimeoutError: except TimeoutError:
timed_out = True timed_out = True
assert isinstance(request.node, pytest.Item)
request.node.user_properties.append(("steps", steps)) request.node.user_properties.append(("steps", steps))
request.node.user_properties.append(("n_steps", n_steps)) request.node.user_properties.append(("n_steps", n_steps))
request.node.user_properties.append(("timed_out", timed_out)) request.node.user_properties.append(("timed_out", timed_out))
@ -411,15 +410,10 @@ class BuiltinChallenge(BaseChallenge):
def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]: def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
logger.info("Loading built-in challenges...") logger.info("Loading built-in challenges...")
challenges_path = os.path.dirname(__file__) challenges_path = Path(__file__).parent
logger.debug(f"Looking for challenge spec files in {challenges_path}...") logger.debug(f"Looking for challenge spec files in {challenges_path}...")
json_files = deque( json_files = deque(challenges_path.rglob("data.json"))
glob.glob(
f"{challenges_path}/**/data.json",
recursive=True,
)
)
logger.debug(f"Found {len(json_files)} built-in challenges.") logger.debug(f"Found {len(json_files)} built-in challenges.")
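`load_builtin_challenges` now discovers spec files with `pathlib` instead of `glob.glob`. A short sketch of the equivalent discovery logic; the file-name pattern matches the diff, the surrounding standalone script is assumed:

```python
from collections import deque
from pathlib import Path

challenges_path = Path(__file__).parent
# rglob replaces glob.glob(f"{...}/**/data.json", recursive=True) and yields
# Path objects, so callers no longer need to re-wrap strings in Path()
json_files = deque(challenges_path.rglob("data.json"))
print(f"Found {len(json_files)} challenge spec files under {challenges_path}")
```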
@ -431,7 +425,7 @@ def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
ignored += 1 ignored += 1
continue continue
challenge = BuiltinChallenge.from_challenge_spec_file(Path(json_file)) challenge = BuiltinChallenge.from_challenge_spec_file(json_file)
logger.debug(f"Generated test for {challenge.info.name}") logger.debug(f"Generated test for {challenge.info.name}")
yield challenge yield challenge
@ -442,8 +436,8 @@ def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
) )
def _challenge_should_be_ignored(json_file_path: str): def _challenge_should_be_ignored(json_file_path: Path):
return ( return (
"challenges/deprecated" in json_file_path "challenges/deprecated" in json_file_path.as_posix()
or "challenges/library" in json_file_path or "challenges/library" in json_file_path.as_posix()
) )
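`_challenge_should_be_ignored` now compares against `json_file_path.as_posix()`, which makes the exclusion check work on Windows paths. A tiny illustration with an assumed example path, showing why a plain substring test fails against Windows-style separators:

```python
from pathlib import PureWindowsPath

p = PureWindowsPath("agbenchmark/challenges/deprecated/old/data.json")
print(str(p))                                   # backslash-separated on Windows
print("challenges/deprecated" in str(p))        # False: separators don't match
print("challenges/deprecated" in p.as_posix())  # True: normalized to forward slashes
```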


@ -23,9 +23,10 @@ def test_get_ethereum_price() -> None:
real_eth_price_value = float(real_eth_price) real_eth_price_value = float(real_eth_price)
# Check if the eth price is within $50 of the actual Ethereum price # Check if the eth price is within $50 of the actual Ethereum price
assert ( assert abs(real_eth_price_value - eth_price_value) <= 50, (
abs(real_eth_price_value - eth_price_value) <= 50 "AssertionError: Ethereum price is not within $50 of the actual Ethereum price "
), f"AssertionError: Ethereum price is not within $50 of the actual Ethereum price (Provided price: ${eth_price}, Real price: ${real_eth_price})" f"(Provided price: ${eth_price}, Real price: ${real_eth_price})"
)
print("Matches") print("Matches")


@ -23,9 +23,10 @@ def test_get_ethereum_price() -> None:
real_eth_price_value = float(real_eth_price) real_eth_price_value = float(real_eth_price)
# Check if the eth price is within $50 of the actual Ethereum price # Check if the eth price is within $50 of the actual Ethereum price
assert ( assert abs(real_eth_price_value - eth_price_value) <= 50, (
abs(real_eth_price_value - eth_price_value) <= 50 "AssertionError: Ethereum price is not within $50 of the actual Ethereum price "
), f"AssertionError: Ethereum price is not within $50 of the actual Ethereum price (Provided price: ${eth_price}, Real price: ${real_eth_price})" f"(Provided price: ${eth_price}, Real price: ${real_eth_price})"
)
print("Matches") print("Matches")


@ -1,4 +1,3 @@
# mypy: ignore-errors
from typing import List, Optional from typing import List, Optional


@ -1,4 +1,4 @@
# mypy: ignore-errors # pyright: reportMissingImports=false
from typing import List from typing import List
from sample_code import three_sum from sample_code import three_sum


@ -21,7 +21,6 @@ def generate_password(length: int = 8) -> str:
if __name__ == "__main__": if __name__ == "__main__":
password_length = ( password_length = (
int(sys.argv[sys.argv.index("--length") + 1]) int(sys.argv[sys.argv.index("--length") + 1]) if "--length" in sys.argv else 8
if "--length" in sys.argv else 8
) )
print(generate_password(password_length)) print(generate_password(password_length))


@ -1,3 +1,4 @@
# pyright: reportMissingImports=false
import unittest import unittest
import password_generator import password_generator
@ -18,7 +19,9 @@ class TestPasswordGenerator(unittest.TestCase):
def test_password_content(self): def test_password_content(self):
password = password_generator.generate_password() password = password_generator.generate_password()
self.assertTrue(any(c.isdigit() for c in password)) self.assertTrue(any(c.isdigit() for c in password))
self.assertTrue(any(c in password_generator.string.punctuation for c in password)) self.assertTrue(
any(c in password_generator.string.punctuation for c in password)
)
if __name__ == "__main__": if __name__ == "__main__":


@ -1,3 +1,4 @@
# pyright: reportMissingImports=false
import unittest import unittest
from url_shortener import retrieve_url, shorten_url from url_shortener import retrieve_url, shorten_url


@ -56,7 +56,7 @@ def winner(board):
def getLocation(): def getLocation():
location = input( location = input(
"Choose where to play. Enter two numbers separated by a comma, for example: 1,1 " "Choose where to play. Enter two numbers separated by a comma [example: 1,1]: "
) )
print(f"\nYou picked {location}") print(f"\nYou picked {location}")
coordinates = [int(x) for x in location.split(",")] coordinates = [int(x) for x in location.split(",")]
@ -69,7 +69,8 @@ def getLocation():
): ):
print("You inputted a location in an invalid format") print("You inputted a location in an invalid format")
location = input( location = input(
"Choose where to play. Enter two numbers separated by a comma, for example: 1,1 " "Choose where to play. Enter two numbers separated by a comma "
"[example: 1,1]: "
) )
coordinates = [int(x) for x in location.split(",")] coordinates = [int(x) for x in location.split(",")]
return coordinates return coordinates


@ -37,15 +37,14 @@ class GameStatus(BaseModel):
winner: Optional[str] winner: Optional[str]
from typing import List
class Game(BaseModel): class Game(BaseModel):
game_id: str game_id: str
players: List[str] players: list[str]
board: dict # This could represent the state of the game board, you might need to flesh this out further # This could represent the state of the game board,
ships: List[ShipPlacement] # List of ship placements for this game # you might need to flesh this out further:
turns: List[Turn] # List of turns that have been taken board: dict
ships: list[ShipPlacement] # List of ship placements for this game
turns: list[Turn] # List of turns that have been taken
class AbstractBattleship(ABC): class AbstractBattleship(ABC):
@ -86,7 +85,7 @@ class AbstractBattleship(ABC):
pass pass
@abstractmethod @abstractmethod
def get_game(self) -> Game: def get_game(self) -> Game | None:
""" """
Retrieve the state of the game. Retrieve the state of the game.
""" """
@ -103,5 +102,8 @@ class AbstractBattleship(ABC):
def create_game(self) -> None: def create_game(self) -> None:
""" """
Create a new game. Create a new game.
Returns:
str: The ID of the created game.
""" """
pass pass


@ -1,3 +1,4 @@
# pyright: reportMissingImports=false
import pytest import pytest
from abstract_class import ShipPlacement, Turn from abstract_class import ShipPlacement, Turn
from battleship import Battleship from battleship import Battleship


@ -50,7 +50,7 @@ def test_cant_hit_before_ships_placed(battleship_game):
def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id): def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id):
game = battleship_game.get_game(initialized_game_id) battleship_game.get_game(initialized_game_id)
additional_ship = ShipPlacement( additional_ship = ShipPlacement(
ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal" ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal"
) )


@ -61,6 +61,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
{"row": 1, "column": "H"}, {"row": 1, "column": "H"},
] ]
response = None
for index, hit in enumerate(hits): for index, hit in enumerate(hits):
turn = Turn(target={"row": 2, "column": hit}) turn = Turn(target={"row": 2, "column": hit})
response = battleship_game.create_turn(initialized_game_id, turn) response = battleship_game.create_turn(initialized_game_id, turn)
@ -69,7 +70,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
static_turn = Turn(target=static_moves[index]) static_turn = Turn(target=static_moves[index])
battleship_game.create_turn(initialized_game_id, static_turn) battleship_game.create_turn(initialized_game_id, static_turn)
assert response.result == "sunk" assert response and response.result == "sunk"
def test_restart_game(battleship_game): def test_restart_game(battleship_game):
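The test hunks above initialise `response = None` before the loop and assert on it afterwards. A toy sketch of the reason: without the initialisation the name is possibly unbound if the loop body never runs, and the assert both guards that case at runtime and narrows the type for the checker. Values here are made up:

```python
hits = ["A", "B", "C"]
response: str | None = None  # ensure the name is always bound
for hit in hits:
    response = f"processed {hit}"
assert response and response.endswith("C")  # narrows response to str
```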


@ -37,15 +37,14 @@ class GameStatus(BaseModel):
winner: Optional[str] winner: Optional[str]
from typing import List
class Game(BaseModel): class Game(BaseModel):
game_id: str game_id: str
players: List[str] players: list[str]
board: dict # This could represent the state of the game board, you might need to flesh this out further # This could represent the state of the game board,
ships: List[ShipPlacement] # List of ship placements for this game # you might need to flesh this out further:
turns: List[Turn] # List of turns that have been taken board: dict
ships: list[ShipPlacement] # List of ship placements for this game
turns: list[Turn] # List of turns that have been taken
class AbstractBattleship(ABC): class AbstractBattleship(ABC):
@ -86,7 +85,7 @@ class AbstractBattleship(ABC):
pass pass
@abstractmethod @abstractmethod
def get_game(self) -> Game: def get_game(self, game_id: str) -> Game | None:
""" """
Retrieve the state of the game. Retrieve the state of the game.
""" """
@ -100,8 +99,11 @@ class AbstractBattleship(ABC):
pass pass
@abstractmethod @abstractmethod
def create_game(self) -> None: def create_game(self) -> str:
""" """
Create a new game. Create a new game.
Returns:
str: The ID of the created game.
""" """
pass pass


@ -1,14 +1,20 @@
from typing import Dict from typing import Dict
from abstract_class import (AbstractBattleship, Game, GameStatus, from abstract_class import (
ShipPlacement, Turn, TurnResponse) AbstractBattleship,
Game,
GameStatus,
ShipPlacement,
Turn,
TurnResponse,
)
class Battleship(AbstractBattleship): class Battleship(AbstractBattleship):
def __init__(self): def __init__(self):
self.games: Dict[int, Game] = {} self.games: Dict[str, Game] = {}
def create_game(self) -> int: def create_game(self) -> str:
game_id = str(len(self.games)) game_id = str(len(self.games))
new_game = Game( new_game = Game(
game_id=game_id, game_id=game_id,
@ -19,7 +25,7 @@ class Battleship(AbstractBattleship):
) )
self.games[game_id] = new_game self.games[game_id] = new_game
return new_game.game_id return game_id
def create_ship_placement(self, game_id: str, placement: ShipPlacement) -> None: def create_ship_placement(self, game_id: str, placement: ShipPlacement) -> None:
game = self.games.get(game_id) game = self.games.get(game_id)
@ -79,38 +85,34 @@ class Battleship(AbstractBattleship):
game.turns.append(turn) game.turns.append(turn)
if hit_ship == "hit": if not hit_ship or hit_ship == "hit": # if no ship or already hit
return TurnResponse(result="miss", ship_type=None) return TurnResponse(result="miss", ship_type=None)
if hit_ship: ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship) start_row, start_col = (
ship_placement.start["row"],
ord(ship_placement.start["column"]) - ord("A"),
)
ship_positions = [
(
start_row + (i if ship_placement.direction == "vertical" else 0),
start_col + (i if ship_placement.direction == "horizontal" else 0),
)
for i in range(self.SHIP_LENGTHS[hit_ship])
]
if hit_ship: targeted_positions = {
ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship) (t.target["row"], ord(t.target["column"]) - ord("A")) for t in game.turns
start_row, start_col = ship_placement.start["row"], ord( }
ship_placement.start["column"]
) - ord("A")
ship_positions = [
(
start_row + (i if ship_placement.direction == "vertical" else 0),
start_col + (i if ship_placement.direction == "horizontal" else 0),
)
for i in range(self.SHIP_LENGTHS[hit_ship])
]
targeted_positions = { game.board[(target_row, target_col)] = "hit"
(t.target["row"], ord(t.target["column"]) - ord("A"))
for t in game.turns
}
game.board[(target_row, target_col)] = "hit" if set(ship_positions).issubset(targeted_positions):
for pos in ship_positions:
if set(ship_positions).issubset(targeted_positions): game.board[pos] = "hit"
for pos in ship_positions: return TurnResponse(result="sunk", ship_type=hit_ship)
game.board[pos] = "hit" else:
return TurnResponse(result="sunk", ship_type=hit_ship) return TurnResponse(result="hit", ship_type=hit_ship)
else:
return TurnResponse(result="hit", ship_type=hit_ship)
def get_game_status(self, game_id: str) -> GameStatus: def get_game_status(self, game_id: str) -> GameStatus:
game = self.games.get(game_id) game = self.games.get(game_id)
@ -132,12 +134,12 @@ class Battleship(AbstractBattleship):
def get_winner(self, game_id: str) -> str: def get_winner(self, game_id: str) -> str:
game_status = self.get_game_status(game_id) game_status = self.get_game_status(game_id)
if game_status.is_game_over: if game_status.is_game_over and game_status.winner:
return game_status.winner return game_status.winner
else: else:
return None raise ValueError(f"Game {game_id} isn't over yet")
def get_game(self, game_id: str) -> Game: def get_game(self, game_id: str) -> Game | None:
return self.games.get(game_id) return self.games.get(game_id)
def delete_game(self, game_id: str) -> None: def delete_game(self, game_id: str) -> None:


@ -50,7 +50,7 @@ def test_cant_hit_before_ships_placed(battleship_game):
def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id): def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id):
game = battleship_game.get_game(initialized_game_id) battleship_game.get_game(initialized_game_id)
additional_ship = ShipPlacement( additional_ship = ShipPlacement(
ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal" ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal"
) )


@ -61,6 +61,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
{"row": 1, "column": "H"}, {"row": 1, "column": "H"},
] ]
response = None
for index, hit in enumerate(hits): for index, hit in enumerate(hits):
turn = Turn(target={"row": 2, "column": hit}) turn = Turn(target={"row": 2, "column": hit})
response = battleship_game.create_turn(initialized_game_id, turn) response = battleship_game.create_turn(initialized_game_id, turn)
@ -69,7 +70,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
static_turn = Turn(target=static_moves[index]) static_turn = Turn(target=static_moves[index])
battleship_game.create_turn(initialized_game_id, static_turn) battleship_game.create_turn(initialized_game_id, static_turn)
assert response.result == "sunk" assert response and response.result == "sunk"
def test_restart_game(battleship_game): def test_restart_game(battleship_game):


@ -6,7 +6,7 @@ from typing import ClassVar, Iterator, Literal
import pytest import pytest
import requests import requests
from agent_protocol_client import AgentApi, Step from agent_protocol_client import AgentApi, Step
from pydantic import BaseModel, validator, ValidationError from pydantic import BaseModel, ValidationError, validator
from agbenchmark.config import AgentBenchmarkConfig from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.utils.data_types import Category, EvalResult from agbenchmark.utils.data_types import Category, EvalResult
@ -93,11 +93,12 @@ class Eval(ABC):
... ...
class StringEval(BaseModel, Eval): class BaseStringEval(BaseModel, Eval):
type: ReferenceAnswerType # type: ReferenceAnswerType
pass
class ExactStringMatchEval(StringEval): class ExactStringMatchEval(BaseStringEval):
type: Literal["exact_match"] = "exact_match" type: Literal["exact_match"] = "exact_match"
reference_answer: str reference_answer: str
@ -109,7 +110,7 @@ class ExactStringMatchEval(StringEval):
return string == self.reference_answer return string == self.reference_answer
class FuzzyStringMatchEval(StringEval): class FuzzyStringMatchEval(BaseStringEval):
type: Literal["fuzzy_match"] = "fuzzy_match" type: Literal["fuzzy_match"] = "fuzzy_match"
reference_answer: str reference_answer: str
@ -122,7 +123,7 @@ class FuzzyStringMatchEval(StringEval):
return self.reference_answer.lower() in string.lower() return self.reference_answer.lower() in string.lower()
class MustIncludeStringEval(StringEval): class MustIncludeStringEval(BaseStringEval):
type: Literal["must_include"] = "must_include" type: Literal["must_include"] = "must_include"
reference_answer: str reference_answer: str
@ -134,6 +135,9 @@ class MustIncludeStringEval(StringEval):
return self.reference_answer.lower() in string.lower() return self.reference_answer.lower() in string.lower()
StringEval = ExactStringMatchEval | FuzzyStringMatchEval | MustIncludeStringEval
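The renamed `BaseStringEval` plus the `StringEval` union alias added here follow a base-class-with-union-of-subclasses pattern. A compact sketch with generic stand-in names (these are not the benchmark's real eval classes):

```python
from pydantic import BaseModel


class BaseEval(BaseModel):
    reference_answer: str

    def passes(self, answer: str) -> bool:
        raise NotImplementedError


class ExactMatch(BaseEval):
    def passes(self, answer: str) -> bool:
        return answer == self.reference_answer


class MustInclude(BaseEval):
    def passes(self, answer: str) -> bool:
        return self.reference_answer.lower() in answer.lower()


# Union alias for annotations: the checker knows every member is a concrete subclass
AnyEval = ExactMatch | MustInclude


def run(e: AnyEval, answer: str) -> bool:
    return e.passes(answer)


print(run(MustInclude(reference_answer="42"), "The answer is 42."))
```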
class UrlMatchEval(BaseModel, Eval): class UrlMatchEval(BaseModel, Eval):
url: str url: str
"""Example: `"__WIKI__/wiki/Octopus"`""" """Example: `"__WIKI__/wiki/Octopus"`"""
@ -142,8 +146,8 @@ class UrlMatchEval(BaseModel, Eval):
def description(self) -> str: def description(self) -> str:
return f"Agent must navigate to '{self.url}'" return f"Agent must navigate to '{self.url}'"
def evaluate(self, url: str) -> bool: def evaluate(self, string: str) -> bool:
return url == resolve_uri(self.url) return string == resolve_uri(self.url)
class ProgramHtmlEval(BaseModel): class ProgramHtmlEval(BaseModel):
@ -258,7 +262,8 @@ class WebArenaChallengeSpec(BaseModel):
f"{' and '.join(s.base_url for s in sites)}.\n\n" f"{' and '.join(s.base_url for s in sites)}.\n\n"
+ "\n".join( + "\n".join(
s.additional_info.format(url=s.base_url) s.additional_info.format(url=s.base_url)
for s in sites if s.additional_info for s in sites
if s.additional_info
) )
).strip() ).strip()
@ -391,7 +396,9 @@ class WebArenaChallenge(BaseChallenge):
if request.config.getoption("--nc"): if request.config.getoption("--nc"):
timeout = 100000 timeout = 100000
elif cutoff := request.config.getoption("--cutoff"): elif cutoff := request.config.getoption("--cutoff"):
timeout = int(cutoff) timeout = int(cutoff) # type: ignore
assert isinstance(request.node, pytest.Item)
n_steps = 0 n_steps = 0
timed_out = None timed_out = None
@ -400,7 +407,7 @@ class WebArenaChallenge(BaseChallenge):
eval_results_per_step: list[list[tuple[_Eval, EvalResult]]] = [] eval_results_per_step: list[list[tuple[_Eval, EvalResult]]] = []
try: try:
async for step in self.run_challenge( async for step in self.run_challenge(
config, timeout, mock=request.config.getoption("--mock") config, timeout, mock=bool(request.config.getoption("--mock"))
): ):
if not step.output: if not step.output:
logger.warn(f"Step has no output: {step}") logger.warn(f"Step has no output: {step}")
@ -415,7 +422,7 @@ class WebArenaChallenge(BaseChallenge):
) )
step_eval_results = self.evaluate_step_result( step_eval_results = self.evaluate_step_result(
step, mock=request.config.getoption("--mock") step, mock=bool(request.config.getoption("--mock"))
) )
logger.debug(f"Intermediary results: {step_eval_results}") logger.debug(f"Intermediary results: {step_eval_results}")
eval_results_per_step.append(step_eval_results) eval_results_per_step.append(step_eval_results)
@ -462,7 +469,7 @@ class WebArenaChallenge(BaseChallenge):
def load_webarena_challenges( def load_webarena_challenges(
skip_unavailable: bool = True skip_unavailable: bool = True,
) -> Iterator[type[WebArenaChallenge]]: ) -> Iterator[type[WebArenaChallenge]]:
logger.info("Loading WebArena challenges...") logger.info("Loading WebArena challenges...")


@ -123,8 +123,10 @@ def check_regression(request: pytest.FixtureRequest) -> None:
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file) rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
assert isinstance(request.node, pytest.Function)
assert isinstance(request.node.parent, pytest.Class)
test_name = request.node.parent.name test_name = request.node.parent.name
challenge_location = getattr(request.node.parent.cls, "CHALLENGE_LOCATION", "") challenge_location = getattr(request.node.cls, "CHALLENGE_LOCATION", "")
skip_string = f"Skipping {test_name} at {challenge_location}" skip_string = f"Skipping {test_name} at {challenge_location}"
# Check if the test name exists in the regression tests # Check if the test name exists in the regression tests
@ -148,7 +150,9 @@ def mock(request: pytest.FixtureRequest) -> bool:
Returns: Returns:
bool: Whether `--mock` is set for this session. bool: Whether `--mock` is set for this session.
""" """
return request.config.getoption("--mock") mock = request.config.getoption("--mock")
assert isinstance(mock, bool)
return mock
test_reports: dict[str, Test] = {} test_reports: dict[str, Test] = {}
@ -221,7 +225,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
items: list[pytest.Item], config: pytest.Config items: list[pytest.Function], config: pytest.Config
) -> None: ) -> None:
""" """
Pytest hook that is called after initial test collection has been performed. Pytest hook that is called after initial test collection has been performed.
@ -248,8 +252,9 @@ def pytest_collection_modifyitems(
i = 0 i = 0
while i < len(items): while i < len(items):
item = items[i] item = items[i]
assert item.cls and issubclass(item.cls, BaseChallenge)
challenge = item.cls challenge = item.cls
challenge_name = item.cls.__name__ challenge_name = challenge.info.name
if not issubclass(challenge, BaseChallenge): if not issubclass(challenge, BaseChallenge):
item.warn( item.warn(


@ -18,9 +18,9 @@ def run_benchmark(
maintain: bool = False, maintain: bool = False,
improve: bool = False, improve: bool = False,
explore: bool = False, explore: bool = False,
tests: tuple[str] = tuple(), tests: tuple[str, ...] = tuple(),
categories: tuple[str] = tuple(), categories: tuple[str, ...] = tuple(),
skip_categories: tuple[str] = tuple(), skip_categories: tuple[str, ...] = tuple(),
attempts_per_challenge: int = 1, attempts_per_challenge: int = 1,
mock: bool = False, mock: bool = False,
no_dep: bool = False, no_dep: bool = False,


@ -53,9 +53,9 @@ class SingletonReportManager:
@classmethod @classmethod
def clear_instance(cls): def clear_instance(cls):
cls.instance = None cls.instance = None
cls.INFO_MANAGER = None del cls.INFO_MANAGER
cls.REGRESSION_MANAGER = None del cls.REGRESSION_MANAGER
cls.SUCCESS_RATE_TRACKER = None del cls.SUCCESS_RATE_TRACKER
class BaseReportManager: class BaseReportManager:
@ -99,7 +99,8 @@ class BaseReportManager:
class SessionReportManager(BaseReportManager): class SessionReportManager(BaseReportManager):
"""Abstracts interaction with the regression tests file""" """Abstracts interaction with the regression tests file"""
tests: dict[str, Test] | Report tests: dict[str, Test]
report: Report | None = None
def __init__(self, report_file: Path, benchmark_start_time: datetime): def __init__(self, report_file: Path, benchmark_start_time: datetime):
super().__init__(report_file) super().__init__(report_file)
@ -109,20 +110,21 @@ class SessionReportManager(BaseReportManager):
def save(self) -> None: def save(self) -> None:
with self.report_file.open("w") as f: with self.report_file.open("w") as f:
if isinstance(self.tests, Report): if self.report:
f.write(self.tests.json(indent=4)) f.write(self.report.json(indent=4))
else: else:
json.dump({k: v.dict() for k, v in self.tests.items()}, f, indent=4) json.dump({k: v.dict() for k, v in self.tests.items()}, f, indent=4)
def load(self) -> None: def load(self) -> None:
super().load() super().load()
if "tests" in self.tests: # type: ignore
self.tests = Report.parse_obj(self.tests) if "tests" in self.tests:
self.report = Report.parse_obj(self.tests)
else: else:
self.tests = {n: Test.parse_obj(d) for n, d in self.tests.items()} self.tests = {n: Test.parse_obj(d) for n, d in self.tests.items()}
def add_test_report(self, test_name: str, test_report: Test) -> None: def add_test_report(self, test_name: str, test_report: Test) -> None:
if isinstance(self.tests, Report): if self.report:
raise RuntimeError("Session report already finalized") raise RuntimeError("Session report already finalized")
if test_name.startswith("Test"): if test_name.startswith("Test"):
@ -134,10 +136,10 @@ class SessionReportManager(BaseReportManager):
def finalize_session_report(self, config: AgentBenchmarkConfig) -> None: def finalize_session_report(self, config: AgentBenchmarkConfig) -> None:
command = " ".join(sys.argv) command = " ".join(sys.argv)
if isinstance(self.tests, Report): if self.report:
raise RuntimeError("Session report already finalized") raise RuntimeError("Session report already finalized")
self.tests = Report( self.report = Report(
command=command.split(os.sep)[-1], command=command.split(os.sep)[-1],
benchmark_git_commit_sha="---", benchmark_git_commit_sha="---",
agent_git_commit_sha="---", agent_git_commit_sha="---",
@ -156,7 +158,7 @@ class SessionReportManager(BaseReportManager):
config=config.dict(exclude={"reports_folder"}, exclude_none=True), config=config.dict(exclude={"reports_folder"}, exclude_none=True),
) )
agent_categories = get_highest_achieved_difficulty_per_category(self.tests) agent_categories = get_highest_achieved_difficulty_per_category(self.report)
if len(agent_categories) > 1: if len(agent_categories) > 1:
save_single_radar_chart( save_single_radar_chart(
agent_categories, agent_categories,
@ -166,8 +168,8 @@ class SessionReportManager(BaseReportManager):
self.save() self.save()
def get_total_costs(self): def get_total_costs(self):
if isinstance(self.tests, Report): if self.report:
tests = self.tests.tests tests = self.report.tests
else: else:
tests = self.tests tests = self.tests
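`SessionReportManager` now keeps `tests` and an optional `report` as two separate attributes instead of one union-typed attribute. A simplified sketch of that state split; class and method names are trimmed-down stand-ins, not the real report manager:

```python
from pydantic import BaseModel


class Test(BaseModel):
    success: bool = False


class Report(BaseModel):
    tests: dict[str, Test] = {}


class SessionReport:
    def __init__(self) -> None:
        self.tests: dict[str, Test] = {}
        self.report: Report | None = None  # set once the session is finalized

    def add(self, name: str, test: Test) -> None:
        if self.report:
            raise RuntimeError("Session report already finalized")
        self.tests[name] = test

    def finalize(self) -> None:
        self.report = Report(tests=self.tests)
```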


@ -3,7 +3,7 @@ Model definitions used internally and for reports generated during command-line
""" """
import logging import logging
from typing import Any, Dict, List from typing import Annotated, Any, Dict, List
from agent_protocol_client import Step from agent_protocol_client import Step
from pydantic import BaseModel, Field, constr, validator from pydantic import BaseModel, Field, constr, validator
@ -88,7 +88,7 @@ class Test(BaseModel):
class ReportBase(BaseModel): class ReportBase(BaseModel):
command: str command: str
completion_time: str | None = None completion_time: str | None = None
benchmark_start_time: constr(regex=datetime_format) benchmark_start_time: Annotated[str, constr(regex=datetime_format)]
metrics: MetricsOverall metrics: MetricsOverall
config: Dict[str, str | dict[str, str]] config: Dict[str, str | dict[str, str]]
agent_git_commit_sha: str | None = None agent_git_commit_sha: str | None = None


@ -1,4 +1,5 @@
"""Model definitions for use in the API""" """Model definitions for use in the API"""
from typing import Annotated
from pydantic import BaseModel, constr from pydantic import BaseModel, constr
@ -36,7 +37,7 @@ class RunDetails(BaseModel):
run_id: str | None = None run_id: str | None = None
command: str command: str
completion_time: str | None = None completion_time: str | None = None
benchmark_start_time: constr(regex=datetime_format) benchmark_start_time: Annotated[str, constr(regex=datetime_format)]
class BenchmarkRun(BaseModel): class BenchmarkRun(BaseModel):


@ -1,14 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class TaskInput(BaseModel):
pass
class TaskRequestBody(BaseModel): class TaskRequestBody(BaseModel):
input: str = Field( input: str = Field(
..., ...,
@ -16,7 +12,7 @@ class TaskRequestBody(BaseModel):
description="Input prompt for the task.", description="Input prompt for the task.",
example="Write the words you receive to the file 'output.txt'.", example="Write the words you receive to the file 'output.txt'.",
) )
additional_input: Optional[TaskInput] = {} additional_input: Optional[dict[str, Any]] = Field(default_factory=dict)
class TaskEvalRequestBody(TaskRequestBody): class TaskEvalRequestBody(TaskRequestBody):
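`TaskRequestBody.additional_input` changes from a literal `{}` default typed as `Optional[TaskInput]` to a plain dict field with `Field(default_factory=dict)`. A small runnable sketch of that field style (the class name here is shortened):

```python
from typing import Any, Optional

from pydantic import BaseModel, Field


class TaskRequest(BaseModel):
    input: str
    additional_input: Optional[dict[str, Any]] = Field(default_factory=dict)


a = TaskRequest(input="task A")
b = TaskRequest(input="task B")
# each instance gets its own fresh dict rather than sharing one default object
assert a.additional_input == {} and a.additional_input is not b.additional_input
```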


@ -32,7 +32,10 @@ def _add_ini_and_option(
default: str | bool | int, default: str | bool | int,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Add an option to both the ini file as well as the command line flags, with the latter overriding the former.""" """
Add an option to both the ini file and the command line flags.
Command line flags/options takes precedence over the ini config.
"""
parser.addini( parser.addini(
name, name,
help + " This overrides the similarly named option from the config.", help + " This overrides the similarly named option from the config.",
@ -44,7 +47,10 @@ def _add_ini_and_option(
def _get_ini_or_option( def _get_ini_or_option(
config: Any, name: str, choices: Optional[list[str]] config: Any, name: str, choices: Optional[list[str]]
) -> str | None: ) -> str | None:
"""Get an option from either the ini file or the command line flags, the latter taking precedence.""" """
Get an option from either the ini file or the command line flags,
with the latter taking precedence.
"""
value = config.getini(name) value = config.getini(name)
if value is not None and choices is not None and value not in choices: if value is not None and choices is not None and value not in choices:
raise ValueError( raise ValueError(
@ -73,7 +79,7 @@ def pytest_addoption(parser: Parser) -> None:
default=False, default=False,
help=( help=(
"List all non-nodeid dependency names + the tests they resolve to. " "List all non-nodeid dependency names + the tests they resolve to. "
"Will also list all nodeid dependency names when verbosity is high enough." "Will also list all nodeid dependency names in verbose mode."
), ),
) )
@ -83,7 +89,10 @@ def pytest_addoption(parser: Parser) -> None:
"--list-processed-dependencies", "--list-processed-dependencies",
action="store_true", action="store_true",
default=False, default=False,
help="List all dependencies of all tests as a list of nodeids + the names that could not be resolved.", help=(
"List all dependencies of all tests as a list of nodeids "
"+ the names that could not be resolved."
),
) )
# Add an ini option + flag to choose the action to take for failed dependencies # Add an ini option + flag to choose the action to take for failed dependencies
@ -94,7 +103,8 @@ def pytest_addoption(parser: Parser) -> None:
name="failed_dependency_action", name="failed_dependency_action",
help=( help=(
"The action to take when a test has dependencies that failed. " "The action to take when a test has dependencies that failed. "
'Use "run" to run the test anyway, "skip" to skip the test, and "fail" to fail the test.' 'Use "run" to run the test anyway, "skip" to skip the test, '
'and "fail" to fail the test.'
), ),
default="skip", default="skip",
choices=DEPENDENCY_PROBLEM_ACTIONS.keys(), choices=DEPENDENCY_PROBLEM_ACTIONS.keys(),
@ -107,8 +117,10 @@ def pytest_addoption(parser: Parser) -> None:
group, group,
name="missing_dependency_action", name="missing_dependency_action",
help=( help=(
"The action to take when a test has dependencies that cannot be found within the current scope. " "The action to take when a test has dependencies that cannot be found "
'Use "run" to run the test anyway, "skip" to skip the test, and "fail" to fail the test.' "within the current scope. "
'Use "run" to run the test anyway, "skip" to skip the test, '
'and "fail" to fail the test.'
), ),
default="warning", default="warning",
choices=DEPENDENCY_PROBLEM_ACTIONS.keys(), choices=DEPENDENCY_PROBLEM_ACTIONS.keys(),
@ -139,7 +151,7 @@ def pytest_configure(config: Any) -> None:
@pytest.hookimpl(trylast=True) @pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(config: Any, items: list[Item]) -> None: def pytest_collection_modifyitems(config: Any, items: list[pytest.Function]) -> None:
manager = managers[-1] manager = managers[-1]
# Register the founds tests on the manager # Register the founds tests on the manager


@ -3,7 +3,7 @@
# The name of the marker used # The name of the marker used
MARKER_NAME = "depends" MARKER_NAME = "depends"
# The name of the keyword argument for the marker that contains custom name(s) for the tests # The name of the kwarg for 'depends' markers that contains custom name(s) for the tests
MARKER_KWARG_ID = "name" MARKER_KWARG_ID = "name"
# The name of the keyword argument for the marker that specifies the tests to depend on # The name of the keyword argument for the marker that specifies the tests to depend on


@ -57,8 +57,10 @@ def curved_edges(
""" """
ax = plt.gca() ax = plt.gca()
for u, v, data in G.edges(data=True): for u, v, data in G.edges(data=True):
src = np.array(pos[u]) _src = pos[u]
dst = np.array(pos[v]) _dst = pos[v]
src = np.array(_src)
dst = np.array(_dst)
same_level = abs(src[1] - dst[1]) < 0.01 same_level = abs(src[1] - dst[1]) < 0.01
@ -68,7 +70,7 @@ def curved_edges(
arrow = patches.FancyArrowPatch( arrow = patches.FancyArrowPatch(
posA=curve[0], # type: ignore posA=curve[0], # type: ignore
posB=curve[-1], # type: ignore posB=curve[-1], # type: ignore
connectionstyle=f"arc3,rad=0.2", connectionstyle="arc3,rad=0.2",
color="gray", color="gray",
arrowstyle="-|>", arrowstyle="-|>",
mutation_scale=15.0, mutation_scale=15.0,
@ -80,8 +82,8 @@ def curved_edges(
else: else:
ax.annotate( ax.annotate(
"", "",
xy=dst, xy=_dst,
xytext=src, xytext=_src,
arrowprops=dict( arrowprops=dict(
arrowstyle="-|>", color="gray", lw=1, shrinkA=10, shrinkB=10 arrowstyle="-|>", color="gray", lw=1, shrinkA=10, shrinkB=10
), ),
@ -89,7 +91,8 @@ def curved_edges(
def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, float]]: def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, float]]:
"""Compute positions as a tree layout centered on the root with alternating vertical shifts.""" """Compute positions as a tree layout centered on the root
with alternating vertical shifts."""
bfs_tree = nx.bfs_tree(graph, source=root_node) bfs_tree = nx.bfs_tree(graph, source=root_node)
levels = { levels = {
node: depth node: depth
@ -137,7 +140,7 @@ def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, flo
def graph_spring_layout( def graph_spring_layout(
dag: nx.DiGraph, labels: Dict[Any, str], tree: bool = True dag: nx.DiGraph, labels: Dict[Any, str], tree: bool = True
) -> None: ) -> None:
num_nodes = len(dag.nodes()) num_nodes = len(list(dag.nodes()))
# Setting up the figure and axis # Setting up the figure and axis
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.axis("off") # Turn off the axis ax.axis("off") # Turn off the axis
@ -288,7 +291,8 @@ def graph_interactive_network(
# Optionally, save to a file # Optionally, save to a file
# Sync with the flutter UI # Sync with the flutter UI
# this literally only works in the AutoGPT repo, but this part of the code is not reached if BUILD_SKILL_TREE is false # this literally only works in the AutoGPT repo, but this part of the code
# is not reached if BUILD_SKILL_TREE is false
write_pretty_json(graph_data, flutter_app_path / "tree_structure.json") write_pretty_json(graph_data, flutter_app_path / "tree_structure.json")
validate_skill_tree(graph_data, "") validate_skill_tree(graph_data, "")
@ -332,11 +336,13 @@ def graph_interactive_network(
def extract_subgraph_based_on_category(graph, category): def extract_subgraph_based_on_category(graph, category):
""" """
Extracts a subgraph that includes all nodes and edges required to reach all nodes with a specified category. Extracts a subgraph that includes all nodes and edges required to reach all nodes
with a specified category.
:param graph: The original graph. :param graph: The original graph.
:param category: The target category. :param category: The target category.
:return: Subgraph with nodes and edges required to reach the nodes with the given category. :return: Subgraph with nodes and edges required to reach the nodes
with the given category.
""" """
subgraph = {"nodes": [], "edges": []} subgraph = {"nodes": [], "edges": []}
@ -424,7 +430,8 @@ def get_roots(graph):
def validate_skill_tree(graph, skill_tree_name): def validate_skill_tree(graph, skill_tree_name):
""" """
Validate if a given graph represents a valid skill tree and raise appropriate exceptions if not. Validate if a given graph represents a valid skill tree
and raise appropriate exceptions if not.
:param graph: A dictionary representing the graph with 'nodes' and 'edges'. :param graph: A dictionary representing the graph with 'nodes' and 'edges'.
:raises: ValueError with a description of the invalidity. :raises: ValueError with a description of the invalidity.
@ -434,7 +441,8 @@ def validate_skill_tree(graph, skill_tree_name):
if cycle_path: if cycle_path:
cycle_str = " -> ".join(cycle_path) cycle_str = " -> ".join(cycle_path)
raise ValueError( raise ValueError(
f"{skill_tree_name} skill tree is circular! Circular path detected: {cycle_str}." f"{skill_tree_name} skill tree is circular! "
f"Detected circular path: {cycle_str}."
) )
# Check for multiple roots # Check for multiple roots


@ -1,18 +1,19 @@
""" """
A module to manage dependencies between pytest tests. A module to manage dependencies between pytest tests.
This module provides the methods implementing the main logic. These are used in the pytest hooks that are in This module provides the methods implementing the main logic.
__init__.py. These are used in the pytest hooks that are in __init__.py.
""" """
import collections import collections
import json
import os import os
from typing import Any, Generator from typing import Any, Generator
import colorama import colorama
import networkx import networkx
from _pytest.nodes import Item from pytest import Function, Item
from agbenchmark.challenges.base import BaseChallenge
from .constants import MARKER_KWARG_DEPENDENCIES, MARKER_NAME from .constants import MARKER_KWARG_DEPENDENCIES, MARKER_NAME
from .graphs import graph_interactive_network from .graphs import graph_interactive_network
@ -38,7 +39,8 @@ class TestResult(object):
) )
if result.when in self.results: if result.when in self.results:
raise AttributeError( raise AttributeError(
f"Received multiple results for step {result.when} of test {self.nodeid}" f"Received multiple results for step {result.when} "
f"of test {self.nodeid}"
) )
self.results[result.when] = result.outcome self.results[result.when] = result.outcome
@ -66,7 +68,7 @@ class TestDependencies(object):
for dep in marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, []) for dep in marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])
] ]
for dependency in dependencies: for dependency in dependencies:
# If the name is not known, try to make it absolute (ie file::[class::]method) # If the name is not known, try to make it absolute (file::[class::]method)
if dependency not in manager.name_to_nodeids: if dependency not in manager.name_to_nodeids:
absolute_dependency = get_absolute_nodeid(dependency, self.nodeid) absolute_dependency = get_absolute_nodeid(dependency, self.nodeid)
if absolute_dependency in manager.name_to_nodeids: if absolute_dependency in manager.name_to_nodeids:
@ -86,20 +88,20 @@ class DependencyManager(object):
def __init__(self) -> None: def __init__(self) -> None:
"""Create a new DependencyManager.""" """Create a new DependencyManager."""
self.options: dict[str, Any] = {} self.options: dict[str, Any] = {}
self._items: list[Item] | None = None self._items: list[Function] | None = None
self._name_to_nodeids: Any = None self._name_to_nodeids: Any = None
self._nodeid_to_item: Any = None self._nodeid_to_item: Any = None
self._results: Any = None self._results: Any = None
@property @property
def items(self) -> list[Item]: def items(self) -> list[Function]:
"""The collected tests that are managed by this instance.""" """The collected tests that are managed by this instance."""
if self._items is None: if self._items is None:
raise AttributeError("The items attribute has not been set yet") raise AttributeError("The items attribute has not been set yet")
return self._items return self._items
@items.setter @items.setter
def items(self, items: list[Item]) -> None: def items(self, items: list[Function]) -> None:
if self._items is not None: if self._items is not None:
raise AttributeError("The items attribute has already been set") raise AttributeError("The items attribute has already been set")
self._items = items self._items = items
@ -125,7 +127,8 @@ class DependencyManager(object):
for item in items: for item in items:
nodeid = clean_nodeid(item.nodeid) nodeid = clean_nodeid(item.nodeid)
# Process the dependencies of this test # Process the dependencies of this test
# This uses the mappings created in the previous loop, and can thus not be merged into that loop # This uses the mappings created in the previous loop,
# and can thus not be merged into that loop
self._dependencies[nodeid] = TestDependencies(item, self) self._dependencies[nodeid] = TestDependencies(item, self)
@property @property
@ -135,7 +138,7 @@ class DependencyManager(object):
return self._name_to_nodeids return self._name_to_nodeids
@property @property
def nodeid_to_item(self) -> dict[str, Item]: def nodeid_to_item(self) -> dict[str, Function]:
"""A mapping from node ids to test items.""" """A mapping from node ids to test items."""
assert self.items is not None assert self.items is not None
return self._nodeid_to_item return self._nodeid_to_item
@ -194,7 +197,9 @@ class DependencyManager(object):
@property @property
def sorted_items(self) -> Generator: def sorted_items(self) -> Generator:
"""Get a sorted list of tests where all tests are sorted after their dependencies.""" """
Get a sorted list of tests where all tests are sorted after their dependencies.
"""
# Build a directed graph for sorting # Build a directed graph for sorting
build_skill_tree = os.getenv("BUILD_SKILL_TREE") build_skill_tree = os.getenv("BUILD_SKILL_TREE")
BUILD_SKILL_TREE = ( BUILD_SKILL_TREE = (
@ -202,8 +207,8 @@ class DependencyManager(object):
) )
dag = networkx.DiGraph() dag = networkx.DiGraph()
# Insert all items as nodes, to prevent items that have no dependencies and are not dependencies themselves from # Insert all items as nodes, to prevent items that have no dependencies
# being lost # and are not dependencies themselves from being lost
dag.add_nodes_from(self.items) dag.add_nodes_from(self.items)
# Insert edges for all the dependencies # Insert edges for all the dependencies
@ -214,11 +219,8 @@ class DependencyManager(object):
labels = {} labels = {}
for item in self.items: for item in self.items:
try: assert item.cls and issubclass(item.cls, BaseChallenge)
with open(item.cls.CHALLENGE_LOCATION) as f: data = item.cls.info.dict()
data = json.load(f)
except:
data = {}
node_name = get_name(item) node_name = get_name(item)
data["name"] = node_name data["name"] = node_name

View File

@ -38,7 +38,8 @@ def strip_nodeid_parameters(nodeid: str) -> str:
def get_absolute_nodeid(nodeid: str, scope: str) -> str: def get_absolute_nodeid(nodeid: str, scope: str) -> str:
""" """
Transform a possibly relative node id to an absolute one using the scope in which it is used. Transform a possibly relative node id to an absolute one
using the scope in which it is used.
>>> scope = 'test_file.py::TestClass::test' >>> scope = 'test_file.py::TestClass::test'
>>> get_absolute_nodeid('test2', scope) >>> get_absolute_nodeid('test2', scope)
@ -49,7 +50,7 @@ def get_absolute_nodeid(nodeid: str, scope: str) -> str:
'test_file2.py::TestClass2::test2' 'test_file2.py::TestClass2::test2'
""" """
parts = nodeid.split("::") parts = nodeid.split("::")
# Completely relative (test_name), so add the full current scope (either file::class or file) # Completely relative (test_name): add the full current scope (file::class or file)
if len(parts) == 1: if len(parts) == 1:
base_nodeid = scope.rsplit("::", 1)[0] base_nodeid = scope.rsplit("::", 1)[0]
nodeid = f"{base_nodeid}::{nodeid}" nodeid = f"{base_nodeid}::{nodeid}"


@ -15,7 +15,8 @@ def get_data_from_helicone(challenge: str) -> Optional[float]:
# Define the endpoint of your GraphQL server # Define the endpoint of your GraphQL server
url = "https://www.helicone.ai/api/graphql" url = "https://www.helicone.ai/api/graphql"
# Set the headers, usually you'd need to set the content type and possibly an authorization token # Set the headers, usually you'd need to set the content type
# and possibly an authorization token
headers = {"authorization": f"Bearer {os.environ.get('HELICONE_API_KEY')}"} headers = {"authorization": f"Bearer {os.environ.get('HELICONE_API_KEY')}"}
# Define the query, variables, and operation name # Define the query, variables, and operation name


@ -1,7 +1,18 @@
SCORING_MAP = { SCORING_MAP = {
"percentage": "assign a float score that will represent a percentage out of 100. Use decimal points to be even more accurate. 0 represents the worst possible generation, while 100 represents the ideal generation", "percentage": (
"scale": "assign an integer score from a scale of 1-10. 1 represents a really bad generation, while 10 represents an ideal generation", "assign a float score that will represent a percentage out of 100. "
"binary": "assign a binary score of either 0 or 1. 0 represents a failure, while 1 represents a success", "Use decimal points to be even more accurate. "
"0 represents the worst possible generation, "
"while 100 represents the ideal generation"
),
"scale": (
"assign an integer score from a scale of 1-10. "
"1 represents a really bad generation, while 10 represents an ideal generation"
),
"binary": (
"assign a binary score of either 0 or 1. "
"0 represents a failure, while 1 represents a success"
),
} }
@ -17,7 +28,7 @@ Here is the ideal response you're comparing to based on the task:
Here is the current machine generated response to the task that you need to evaluate: Here is the current machine generated response to the task that you need to evaluate:
{response} {response}
""" """ # noqa: E501
RUBRIC_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks. RUBRIC_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks.
In order to score the generated texts you will {scoring}. Make sure to factor in rubric into your thinking, deliberation, and final result regarding scoring. Return nothing but a float score. In order to score the generated texts you will {scoring}. Make sure to factor in rubric into your thinking, deliberation, and final result regarding scoring. Return nothing but a float score.
@ -31,7 +42,7 @@ Use the below rubric to guide your thinking about scoring:
Here is the current machine generated response to the task that you need to evaluate: Here is the current machine generated response to the task that you need to evaluate:
{response} {response}
""" """ # noqa: E501
QUESTION_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks. QUESTION_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks.
In order to score the generated texts you will {scoring}. Make sure to think about whether the generated response answers the question well in order to score accurately. Return nothing but a float score. In order to score the generated texts you will {scoring}. Make sure to think about whether the generated response answers the question well in order to score accurately. Return nothing but a float score.
@ -45,12 +56,12 @@ Here is a question that checks if the task was completed correctly:
Here is the current machine generated response to the task that you need to evaluate: Here is the current machine generated response to the task that you need to evaluate:
{response} {response}
""" """ # noqa: E501
FEW_SHOT_EXAMPLES = """Here are some examples of how to score a machine generated response based on the above: FEW_SHOT_EXAMPLES = """Here are some examples of how to score a machine generated response based on the above:
{examples} {examples}
""" """ # noqa: E501
CUSTOM_PROMPT = """{custom} CUSTOM_PROMPT = """{custom}
{scoring} {scoring}
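The reflowed prompt strings above rely on Python's implicit concatenation of adjacent string literals, so the rubric text stays identical to the single long line it replaces; only the source layout changes to satisfy the line-length limit. A minimal sketch of that behavior (values copied from the `"binary"` entry):

```python
single_line = (
    "assign a binary score of either 0 or 1. "
    "0 represents a failure, while 1 represents a success"
)

# Adjacent string literals are joined at compile time, so the wrapped form
# is byte-for-byte equal to writing the sentence on one long line.
assert single_line == (
    "assign a binary score of either 0 or 1."
    + " 0 represents a failure, while 1 represents a success"
)
```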

View File

@ -202,11 +202,15 @@ def sorted_by_enum_index(
sortable: Iterable[T], sortable: Iterable[T],
enum: type[Enum], enum: type[Enum],
*, *,
key: Callable[[T], Enum | None] = lambda x: x, # type: ignore key: Optional[Callable[[T], Enum | None]] = None,
reverse: bool = False, reverse: bool = False,
) -> list[T]: ) -> list[T]:
return sorted( return sorted(
sortable, sortable,
key=lambda x: enum._member_names_.index(e.name) if (e := key(x)) else 420e3, key=lambda x: (
enum._member_names_.index(e.name) # type: ignore
if (e := key(x) if key else x)
else 420e3
),
reverse=reverse, reverse=reverse,
) )
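A usage sketch for the updated `sorted_by_enum_index` helper, assuming the function shown above is in scope (its import path is not part of this hunk) and using a small illustrative enum. Items whose key resolves to `None` fall back to the large index `420e3` and therefore sort last:

```python
from enum import Enum


class Severity(Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"


items = [
    ("b", Severity.HIGH),
    ("a", Severity.LOW),
    ("c", None),
    ("d", Severity.MEDIUM),
]

# Sort tuples by the declaration order of their Severity member;
# the entry without a severity ends up at the end of the list.
ordered = sorted_by_enum_index(items, Severity, key=lambda item: item[1])
assert [name for name, _ in ordered] == ["a", "d", "b", "c"]
```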

View File

@ -1,13 +0,0 @@
[mypy]
namespace_packages = True
follow_imports = skip
check_untyped_defs = True
disallow_untyped_defs = True
exclude = ^(agbenchmark/challenges/|agent/|venv|venv-dev)
ignore_missing_imports = True
[mypy-agbenchmark.utils.data_types.*]
ignore_errors = True
[mypy-numpy.*]
ignore_errors = True

213
benchmark/poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] [[package]]
name = "agent-protocol-client" name = "agent-protocol-client"
@ -197,63 +197,49 @@ tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
[[package]]
name = "autoflake"
version = "1.7.8"
description = "Removes unused imports and unused variables"
optional = false
python-versions = ">=3.7"
files = [
{file = "autoflake-1.7.8-py3-none-any.whl", hash = "sha256:46373ef69b6714f5064c923bb28bd797c4f8a9497f557d87fc36665c6d956b39"},
{file = "autoflake-1.7.8.tar.gz", hash = "sha256:e7e46372dee46fa1c97acf310d99d922b63d369718a270809d7c278d34a194cf"},
]
[package.dependencies]
pyflakes = ">=1.1.0,<3"
tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""}
[[package]] [[package]]
name = "black" name = "black"
version = "22.3.0" version = "23.12.1"
description = "The uncompromising code formatter." description = "The uncompromising code formatter."
optional = false optional = false
python-versions = ">=3.6.2" python-versions = ">=3.8"
files = [ files = [
{file = "black-22.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2497f9c2386572e28921fa8bec7be3e51de6801f7459dffd6e62492531c47e09"}, {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"},
{file = "black-22.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5795a0375eb87bfe902e80e0c8cfaedf8af4d49694d69161e5bd3206c18618bb"}, {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"},
{file = "black-22.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3556168e2e5c49629f7b0f377070240bd5511e45e25a4497bb0073d9dda776a"}, {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"},
{file = "black-22.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67c8301ec94e3bcc8906740fe071391bce40a862b7be0b86fb5382beefecd968"}, {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"},
{file = "black-22.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:fd57160949179ec517d32ac2ac898b5f20d68ed1a9c977346efbac9c2f1e779d"}, {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"},
{file = "black-22.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cc1e1de68c8e5444e8f94c3670bb48a2beef0e91dddfd4fcc29595ebd90bb9ce"}, {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"},
{file = "black-22.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2fc92002d44746d3e7db7cf9313cf4452f43e9ea77a2c939defce3b10b5c82"}, {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"},
{file = "black-22.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:a6342964b43a99dbc72f72812bf88cad8f0217ae9acb47c0d4f141a6416d2d7b"}, {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"},
{file = "black-22.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:328efc0cc70ccb23429d6be184a15ce613f676bdfc85e5fe8ea2a9354b4e9015"}, {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"},
{file = "black-22.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06f9d8846f2340dfac80ceb20200ea5d1b3f181dd0556b47af4e8e0b24fa0a6b"}, {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"},
{file = "black-22.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4efa5fad66b903b4a5f96d91461d90b9507a812b3c5de657d544215bb7877a"}, {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"},
{file = "black-22.3.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8477ec6bbfe0312c128e74644ac8a02ca06bcdb8982d4ee06f209be28cdf163"}, {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"},
{file = "black-22.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:637a4014c63fbf42a692d22b55d8ad6968a946b4a6ebc385c5505d9625b6a464"}, {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"},
{file = "black-22.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:863714200ada56cbc366dc9ae5291ceb936573155f8bf8e9de92aef51f3ad0f0"}, {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"},
{file = "black-22.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10dbe6e6d2988049b4655b2b739f98785a884d4d6b85bc35133a8fb9a2233176"}, {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"},
{file = "black-22.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:cee3e11161dde1b2a33a904b850b0899e0424cc331b7295f2a9698e79f9a69a0"}, {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"},
{file = "black-22.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5891ef8abc06576985de8fa88e95ab70641de6c1fca97e2a15820a9b69e51b20"}, {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"},
{file = "black-22.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:30d78ba6bf080eeaf0b7b875d924b15cd46fec5fd044ddfbad38c8ea9171043a"}, {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"},
{file = "black-22.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ee8f1f7228cce7dffc2b464f07ce769f478968bfb3dd1254a4c2eeed84928aad"}, {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"},
{file = "black-22.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ee227b696ca60dd1c507be80a6bc849a5a6ab57ac7352aad1ffec9e8b805f21"}, {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"},
{file = "black-22.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:9b542ced1ec0ceeff5b37d69838106a6348e60db7b8fdd245294dc1d26136265"}, {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"},
{file = "black-22.3.0-py3-none-any.whl", hash = "sha256:bc58025940a896d7e5356952228b68f793cf5fcb342be703c3a2669a1488cb72"}, {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"},
{file = "black-22.3.0.tar.gz", hash = "sha256:35020b8886c022ced9282b51b5a875b6d1ab0c387b31a065b84db7c33085ca79"},
] ]
[package.dependencies] [package.dependencies]
click = ">=8.0.0" click = ">=8.0.0"
mypy-extensions = ">=0.4.3" mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0" pathspec = ">=0.9.0"
platformdirs = ">=2" platformdirs = ">=2"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
colorama = ["colorama (>=0.4.3)"] colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.7.4)"] d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"] uvloop = ["uvloop (>=0.15.2)"]
@ -558,6 +544,73 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.6.1)", "types-Pill
test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
[[package]]
name = "coverage"
version = "7.5.1"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
{file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
{file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
{file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
{file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
{file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
{file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
{file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
{file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
{file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
{file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
{file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
{file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
{file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
{file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
{file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
{file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
{file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
{file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
{file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
{file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
{file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras]
toml = ["tomli"]
[[package]] [[package]]
name = "cycler" name = "cycler"
version = "0.12.1" version = "0.12.1"
@ -671,19 +724,19 @@ typing = ["typing-extensions (>=4.8)"]
[[package]] [[package]]
name = "flake8" name = "flake8"
version = "3.9.2" version = "7.0.0"
description = "the modular source code checker: pep8 pyflakes and co" description = "the modular source code checker: pep8 pyflakes and co"
optional = false optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" python-versions = ">=3.8.1"
files = [ files = [
{file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"}, {file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"},
{file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"}, {file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"},
] ]
[package.dependencies] [package.dependencies]
mccabe = ">=0.6.0,<0.7.0" mccabe = ">=0.7.0,<0.8.0"
pycodestyle = ">=2.7.0,<2.8.0" pycodestyle = ">=2.11.0,<2.12.0"
pyflakes = ">=2.3.0,<2.4.0" pyflakes = ">=3.2.0,<3.3.0"
[[package]] [[package]]
name = "fonttools" name = "fonttools"
@ -1376,13 +1429,13 @@ traitlets = "*"
[[package]] [[package]]
name = "mccabe" name = "mccabe"
version = "0.6.1" version = "0.7.0"
description = "McCabe checker, plugin for flake8" description = "McCabe checker, plugin for flake8"
optional = false optional = false
python-versions = "*" python-versions = ">=3.6"
files = [ files = [
{file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
{file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
] ]
[[package]] [[package]]
@ -1973,13 +2026,13 @@ pyasn1 = ">=0.4.6,<0.6.0"
[[package]] [[package]]
name = "pycodestyle" name = "pycodestyle"
version = "2.7.0" version = "2.11.1"
description = "Python style guide checker" description = "Python style guide checker"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" python-versions = ">=3.8"
files = [ files = [
{file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"}, {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"},
{file = "pycodestyle-2.7.0.tar.gz", hash = "sha256:c389c1d06bf7904078ca03399a4816f974a1d590090fecea0c63ec26ebaf1cef"}, {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"},
] ]
[[package]] [[package]]
@ -2047,13 +2100,13 @@ email = ["email-validator (>=1.0.3)"]
[[package]] [[package]]
name = "pyflakes" name = "pyflakes"
version = "2.3.1" version = "3.2.0"
description = "passive checker of Python programs" description = "passive checker of Python programs"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" python-versions = ">=3.8"
files = [ files = [
{file = "pyflakes-2.3.1-py2.py3-none-any.whl", hash = "sha256:7893783d01b8a89811dd72d7dfd4d84ff098e5eed95cfa8905b22bbffe52efc3"}, {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"},
{file = "pyflakes-2.3.1.tar.gz", hash = "sha256:f5bc8ecabc05bb9d291eb5203d6810b49040f6ff446a756326104746cc00c1db"}, {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"},
] ]
[[package]] [[package]]
@ -2085,6 +2138,24 @@ files = [
[package.extras] [package.extras]
diagrams = ["jinja2", "railroad-diagrams"] diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pyright"
version = "1.1.364"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyright-1.1.364-py3-none-any.whl", hash = "sha256:865f1e02873c5dc7427c95acf53659a118574010e6fb364e27e47ec5c46a9f26"},
{file = "pyright-1.1.364.tar.gz", hash = "sha256:612a2106a4078ec57efc22b5620729e9bdf4a3c17caba013b534bd33f7d08e5a"},
]
[package.dependencies]
nodeenv = ">=1.6.0"
[package.extras]
all = ["twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]
[[package]] [[package]]
name = "pysocks" name = "pysocks"
version = "1.7.1" version = "1.7.1"
@ -2137,6 +2208,24 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
name = "pytest-cov"
version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.8.2" version = "2.8.2"
@ -2774,4 +2863,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9" content-hash = "4a980e6d8f54a2f7f6a3c55d4f40ac3a4b27b5ac6573dd2a39e11213a4b126dd"

View File

@ -37,59 +37,49 @@ click-default-group = "^1.2.4"
tabulate = "^0.9.0" tabulate = "^0.9.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
flake8 = "^3.9.2" black = "^23.12.1"
isort = "^5.9.3" flake8 = "^7.0.0"
black = "22.3" isort = "^5.13.1"
autoflake = "^1.4" pyright = "^1.1.364"
pandas = "^2.0.3" pandas = "^2.0.3"
gspread = "^5.10.0" gspread = "^5.10.0"
oauth2client = "^4.1.3" oauth2client = "^4.1.3"
pre-commit = "^3.3.3" pre-commit = "^3.3.3"
pytest-cov = "^5.0.0"
[tool.poetry.scripts]
agbenchmark = "agbenchmark.__main__:cli"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q"
testpaths = [
"tests", "agbenchmark",
]
asyncio_mode = "auto"
markers = [
"interface",
"code",
"memory",
"iterate",
"adaptability",
"safety",
"content_gen",
"product_advisor"
]
filterwarnings = [
"ignore::pytest.PytestAssertRewriteWarning",
"ignore::matplotlib.MatplotlibDeprecationWarning"
]
[tool.black] [tool.black]
line-length = 88 line-length = 88
target-version = ['py310'] target-version = ['py310']
include = '\.pyi?$' include = '\.pyi?$'
packages = ["autogpt"]
extend-exclude = '(/dist|/.venv|/venv|/build|/agent|agbenchmark/challenges)/'
[tool.isort] [tool.isort]
profile = "black" profile = "black"
multi_line_output = 3 skip_glob = ["reports"]
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
skip_glob = [".tox", "__pycache__", "*.pyc", "venv*/*", "reports", "venv", "env", "node_modules", ".env", ".venv", "dist", "agent/*", "agbenchmark/challenges/*"]
[tool.poetry.scripts]
agbenchmark = "agbenchmark.__main__:cli" [tool.pyright]
pythonVersion = "3.10"
exclude = [
"notebooks/**",
"reports/**",
"**/node_modules",
"**/__pycache__",
"**/.*",
]
ignore = [
"../forge/**"
]
[tool.pytest.ini_options]
testpaths = ["tests"]

View File

@ -17,7 +17,7 @@ def print_markdown_report(report_json_file: str):
report = Report.parse_file(report_json_file) report = Report.parse_file(report_json_file)
# Header and metadata # Header and metadata
click.echo(f"# Benchmark Report") click.echo("# Benchmark Report")
click.echo(f"- ⌛ **Run time:** `{report.metrics.run_time}`") click.echo(f"- ⌛ **Run time:** `{report.metrics.run_time}`")
click.echo( click.echo(
f" - **Started at:** `{report.benchmark_start_time[:16].replace('T', '` `')}`" f" - **Started at:** `{report.benchmark_start_time[:16].replace('T', '` `')}`"

View File

@ -1,11 +1,16 @@
import datetime
import time
import pytest import pytest
import requests import requests
URL_BENCHMARK = "http://localhost:8080/ap/v1" URL_BENCHMARK = "http://localhost:8080/ap/v1"
URL_AGENT = "http://localhost:8000/ap/v1" URL_AGENT = "http://localhost:8000/ap/v1"
import datetime try:
import time response = requests.get(f"{URL_AGENT}/agent/tasks")
except requests.exceptions.ConnectionError:
pytest.skip("No agent available to test against", allow_module_level=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -20,7 +25,8 @@ import time
), ),
( (
"f219f3d3-a41b-45a9-a3d0-389832086ee8", "f219f3d3-a41b-45a9-a3d0-389832086ee8",
"Read the file called file_to_read.txt and write its content to a file called output.txt", "Read the file called file_to_read.txt "
"and write its content to a file called output.txt",
1, 1,
"ReadFile", "ReadFile",
False, False,
@ -28,7 +34,11 @@ import time
], ],
) )
def test_entire_workflow( def test_entire_workflow(
eval_id, input_text, expected_artifact_length, test_name, should_be_successful eval_id: str,
input_text: str,
expected_artifact_length: int,
test_name: str,
should_be_successful: bool,
): ):
task_request = {"eval_id": eval_id, "input": input_text} task_request = {"eval_id": eval_id, "input": input_text}
response = requests.get(f"{URL_AGENT}/agent/tasks") response = requests.get(f"{URL_AGENT}/agent/tasks")
@ -64,7 +74,7 @@ def test_entire_workflow(
) )
assert step_response.status_code == 200 assert step_response.status_code == 200
step_response = step_response.json() step_response = step_response.json()
assert step_response["is_last"] == True # Assuming is_last is always True assert step_response["is_last"] is True # Assuming is_last is always True
eval_response = requests.post( eval_response = requests.post(
URL_BENCHMARK + "/agent/tasks/" + task_response_benchmark_id + "/evaluations", URL_BENCHMARK + "/agent/tasks/" + task_response_benchmark_id + "/evaluations",

8
cli.py
View File

@ -131,7 +131,9 @@ def start(agent_name: str, no_setup: bool):
script_dir = os.path.dirname(os.path.realpath(__file__)) script_dir = os.path.dirname(os.path.realpath(__file__))
agent_dir = os.path.join( agent_dir = os.path.join(
script_dir, script_dir,
f"agents/{agent_name}" if agent_name not in ["autogpt", "forge"] else agent_name, f"agents/{agent_name}"
if agent_name not in ["autogpt", "forge"]
else agent_name,
) )
run_command = os.path.join(agent_dir, "run") run_command = os.path.join(agent_dir, "run")
run_bench_command = os.path.join(agent_dir, "run_benchmark") run_bench_command = os.path.join(agent_dir, "run_benchmark")
@ -247,7 +249,9 @@ def start(agent_name, subprocess_args):
script_dir = os.path.dirname(os.path.realpath(__file__)) script_dir = os.path.dirname(os.path.realpath(__file__))
agent_dir = os.path.join( agent_dir = os.path.join(
script_dir, script_dir,
f"agents/{agent_name}" if agent_name not in ["autogpt", "forge"] else agent_name, f"agents/{agent_name}"
if agent_name not in ["autogpt", "forge"]
else agent_name,
) )
benchmark_script = os.path.join(agent_dir, "run_benchmark") benchmark_script = os.path.join(agent_dir, "run_benchmark")
if os.path.exists(agent_dir) and os.path.isfile(benchmark_script): if os.path.exists(agent_dir) and os.path.isfile(benchmark_script):

View File

@ -202,7 +202,7 @@ class MyAgent(Agent):
def __init__( def __init__(
self, self,
settings: AgentSettings, settings: AgentSettings,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
file_storage: FileStorage, file_storage: FileStorage,
legacy_config: Config, legacy_config: Config,
): ):
@ -219,7 +219,7 @@ class MyAgent(Agent):
def __init__( def __init__(
self, self,
settings: AgentSettings, settings: AgentSettings,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
file_storage: FileStorage, file_storage: FileStorage,
legacy_config: Config, legacy_config: Config,
): ):

View File

@ -1,15 +1,11 @@
[flake8] [flake8]
max-line-length = 88 max-line-length = 88
select = "E303, W293, W292, E305, E231, E302" # Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude = exclude =
.tox, .git,
__pycache__, __pycache__/,
*.pyc, *.pyc,
.env .pytest_cache/,
venv*/*, venv*/,
.venv/*, .venv/,
reports/*,
dist/*,
agent/*,
code,
agbenchmark/challenges/*
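The new `extend-ignore = E203, W503` is the usual pairing for Black-formatted code: Black puts a space before `:` when a slice bound is an expression (E203), and the Black/PEP 8 style breaks long expressions before binary operators (W503). A small illustrative snippet, written in that style, that default flake8 settings would otherwise flag:

```python
def summarize(values: list[int], offset: int) -> int:
    # E203: space before ":" in a slice whose bounds are expressions.
    selected = values[offset + 1 : offset + 10]
    # W503: line breaks placed *before* the binary operators when wrapping.
    total = (
        sum(selected)
        + offset
        + len(values)
    )
    return total


print(summarize(list(range(20)), 3))
```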

5
forge/.gitignore vendored
View File

@ -160,7 +160,8 @@ CURRENT_BULLETIN.md
agbenchmark_config/workspace agbenchmark_config/workspace
agbenchmark_config/reports agbenchmark_config/reports
*.sqlite *.sqlite*
*.db
.agbench .agbench
.agbenchmark .agbenchmark
.benchmarks .benchmarks
@ -168,7 +169,7 @@ agbenchmark_config/reports
.pytest_cache .pytest_cache
.vscode .vscode
ig_* ig_*
agent.db
agbenchmark_config/updates.json agbenchmark_config/updates.json
agbenchmark_config/challenges_already_beaten.json agbenchmark_config/challenges_already_beaten.json
agbenchmark_config/temp_folder/* agbenchmark_config/temp_folder/*
test_workspace/

View File

@ -1,43 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ['--maxkb=500']
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.11
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.11
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: 'v1.3.0'
# hooks:
# - id: mypy
- repo: local
hooks:
- id: autoflake
name: autoflake
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
language: python
types: [ python ]
# Mono repo has broken this TODO: fix
# - id: pytest-check
# name: pytest-check
# entry: pytest
# language: system
# pass_filenames: false
# always_run: true

View File

@ -9,27 +9,24 @@ from forge.logging.config import configure_logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logo = """\n\n logo = """\n\n
d8888 888 .d8888b. 8888888b. 88888888888 d8888 888 .d8888b. 8888888b. 88888888888
d88888 888 d88P Y88b 888 Y88b 888 d88P888 888 888 888 888 888 888
d88P888 888 888 888 888 888 888 d88P 888 888 888 888888 .d88b. 888 888 d88P 888
d88P 888 888 888 888888 .d88b. 888 888 d88P 888 d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888 d88P 888 888 888 888 888 888 888 888 888 888
d88P 888 888 888 888 888 888 888 888 888 888 d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888 d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888
8888888888
888
8888888888 888 .d88b. 888d888 .d88b. .d88b.
888 888888 d88""88b 888P" d88P"88b d8P Y8b
888 888 888 888 888 888 888 88888888
8888888 .d88b. 888d888 .d88b. .d88b. 888 Y88..88P 888 Y88b 888 Y8b.
888 d88""88b 888P" d88P"88b d8P Y8b 888 "Y88P" 888 "Y88888 "Y8888
888 888 888 888 888 888 88888888 888
888 Y88..88P 888 Y88b 888 Y8b. Y8b d88P
888 "Y88P" 888 "Y88888 "Y8888
888
Y8b d88P
"Y88P" v0.1.0 "Y88P" v0.1.0
\n""" \n"""

View File

@ -1,15 +1,7 @@
from .base import AgentMeta, BaseAgent, BaseAgentConfiguration, BaseAgentSettings from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from .components import (
AgentComponent, __all__ = [
ComponentEndpointError, "BaseAgent",
ComponentSystemError, "BaseAgentConfiguration",
EndpointPipelineError, "BaseAgentSettings",
) ]
from .protocols import (
AfterExecute,
AfterParse,
CommandProvider,
DirectiveProvider,
ExecutionFailure,
MessageProvider,
)
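Trimming the package `__init__.py` to an explicit `__all__` makes the remaining re-exports intentional: pyflakes stops treating them as unused imports, and `from ... import *` binds exactly the listed names. A tiny sketch with a throwaway module (hypothetical names, standing in for an `__init__.py`):

```python
# demo_exports.py (illustrative stand-in for a package __init__.py)
from dataclasses import dataclass


@dataclass
class BaseThing:
    """Part of the public API."""

    name: str


def _helper() -> None:
    """Internal; deliberately left out of __all__."""


__all__ = ["BaseThing"]
```

`from demo_exports import *` then binds only `BaseThing`; anything else has to be imported explicitly.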

View File

@ -24,7 +24,6 @@ from forge.agent_protocol.models.task import (
TaskStepsListResponse, TaskStepsListResponse,
) )
from forge.file_storage.base import FileStorage from forge.file_storage.base import FileStorage
from forge.utils.exceptions import NotFoundError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -79,7 +78,8 @@ class Agent:
else: else:
logger.warning( logger.warning(
f"Frontend not found. {frontend_path} does not exist. The frontend will not be served" f"Frontend not found. {frontend_path} does not exist. "
"The frontend will not be served."
) )
app.add_middleware(AgentMiddleware, agent=self) app.add_middleware(AgentMiddleware, agent=self)
@ -94,34 +94,25 @@ class Agent:
""" """
Create a task for the agent. Create a task for the agent.
""" """
try: task = await self.db.create_task(
task = await self.db.create_task( input=task_request.input,
input=task_request.input, additional_input=task_request.additional_input,
additional_input=task_request.additional_input, )
) return task
return task
except Exception as e:
raise
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse: async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
""" """
List all tasks that the agent has created. List all tasks that the agent has created.
""" """
try: tasks, pagination = await self.db.list_tasks(page, pageSize)
tasks, pagination = await self.db.list_tasks(page, pageSize) response = TaskListResponse(tasks=tasks, pagination=pagination)
response = TaskListResponse(tasks=tasks, pagination=pagination) return response
return response
except Exception as e:
raise
async def get_task(self, task_id: str) -> Task: async def get_task(self, task_id: str) -> Task:
""" """
Get a task by ID. Get a task by ID.
""" """
try: task = await self.db.get_task(task_id)
task = await self.db.get_task(task_id)
except Exception as e:
raise
return task return task
async def list_steps( async def list_steps(
@ -130,12 +121,9 @@ class Agent:
""" """
List the IDs of all steps that the task has created. List the IDs of all steps that the task has created.
""" """
try: steps, pagination = await self.db.list_steps(task_id, page, pageSize)
steps, pagination = await self.db.list_steps(task_id, page, pageSize) response = TaskStepsListResponse(steps=steps, pagination=pagination)
response = TaskStepsListResponse(steps=steps, pagination=pagination) return response
return response
except Exception as e:
raise
async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step: async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
""" """
@ -147,11 +135,8 @@ class Agent:
""" """
Get a step by ID. Get a step by ID.
""" """
try: step = await self.db.get_step(task_id, step_id)
step = await self.db.get_step(task_id, step_id) return step
return step
except Exception as e:
raise
async def list_artifacts( async def list_artifacts(
self, task_id: str, page: int = 1, pageSize: int = 10 self, task_id: str, page: int = 1, pageSize: int = 10
@ -159,62 +144,45 @@ class Agent:
""" """
List the artifacts that the task has created. List the artifacts that the task has created.
""" """
try: artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
artifacts, pagination = await self.db.list_artifacts( return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
task_id, page, pageSize
)
return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
except Exception as e:
raise
async def create_artifact( async def create_artifact(
self, task_id: str, file: UploadFile, relative_path: str self, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact: ) -> Artifact:
""" """
Create an artifact for the task. Create an artifact for the task.
""" """
data = None
file_name = file.filename or str(uuid4()) file_name = file.filename or str(uuid4())
try: data = b""
data = b"" while contents := file.file.read(1024 * 1024):
while contents := file.file.read(1024 * 1024): data += contents
data += contents # Check if relative path ends with filename
# Check if relative path ends with filename if relative_path.endswith(file_name):
if relative_path.endswith(file_name): file_path = relative_path
file_path = relative_path else:
else: file_path = os.path.join(relative_path, file_name)
file_path = os.path.join(relative_path, file_name)
await self.workspace.write_file(file_path, data) await self.workspace.write_file(file_path, data)
artifact = await self.db.create_artifact( artifact = await self.db.create_artifact(
task_id=task_id, task_id=task_id,
file_name=file_name, file_name=file_name,
relative_path=relative_path, relative_path=relative_path,
agent_created=False, agent_created=False,
) )
except Exception as e:
raise
return artifact return artifact
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: async def get_artifact(self, task_id: str, artifact_id: str) -> StreamingResponse:
""" """
Get an artifact by ID. Get an artifact by ID.
""" """
try: artifact = await self.db.get_artifact(artifact_id)
artifact = await self.db.get_artifact(artifact_id) if artifact.file_name not in artifact.relative_path:
if artifact.file_name not in artifact.relative_path: file_path = os.path.join(artifact.relative_path, artifact.file_name)
file_path = os.path.join(artifact.relative_path, artifact.file_name) else:
else: file_path = artifact.relative_path
file_path = artifact.relative_path retrieved_artifact = self.workspace.read_file(file_path, binary=True)
retrieved_artifact = self.workspace.read_file(file_path)
except NotFoundError as e:
raise
except FileNotFoundError as e:
raise
except Exception as e:
raise
return StreamingResponse( return StreamingResponse(
BytesIO(retrieved_artifact), BytesIO(retrieved_artifact),

View File

@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
import pytest import pytest
from fastapi import UploadFile
from forge.agent_protocol.database.db import AgentDB from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.models.task import ( from forge.agent_protocol.models.task import (
@ -16,16 +17,23 @@ from .agent import Agent
@pytest.fixture @pytest.fixture
def agent(): def agent(test_workspace: Path):
db = AgentDB("sqlite:///test.db") db = AgentDB("sqlite:///test.db")
config = FileStorageConfiguration(root=Path("./test_workspace")) config = FileStorageConfiguration(root=test_workspace)
workspace = LocalFileStorage(config) workspace = LocalFileStorage(config)
return Agent(db, workspace) return Agent(db, workspace)
@pytest.mark.skip @pytest.fixture
def file_upload():
this_file = Path(__file__)
file_handle = this_file.open("rb")
yield UploadFile(file_handle, filename=this_file.name)
file_handle.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_task(agent): async def test_create_task(agent: Agent):
task_request = TaskRequestBody( task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"} input="test_input", additional_input={"input": "additional_test_input"}
) )
@ -33,20 +41,18 @@ async def test_create_task(agent):
assert task.input == "test_input" assert task.input == "test_input"
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_tasks(agent): async def test_list_tasks(agent: Agent):
task_request = TaskRequestBody( task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"} input="test_input", additional_input={"input": "additional_test_input"}
) )
task = await agent.create_task(task_request) await agent.create_task(task_request)
tasks = await agent.list_tasks() tasks = await agent.list_tasks()
assert isinstance(tasks, TaskListResponse) assert isinstance(tasks, TaskListResponse)
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_task(agent): async def test_get_task(agent: Agent):
task_request = TaskRequestBody( task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"} input="test_input", additional_input={"input": "additional_test_input"}
) )
@ -55,9 +61,9 @@ async def test_get_task(agent):
assert retrieved_task.task_id == task.task_id assert retrieved_task.task_id == task.task_id
@pytest.mark.skip @pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_and_execute_step(agent): async def test_execute_step(agent: Agent):
task_request = TaskRequestBody( task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"} input="test_input", additional_input={"input": "additional_test_input"}
) )
@ -65,14 +71,14 @@ async def test_create_and_execute_step(agent):
step_request = StepRequestBody( step_request = StepRequestBody(
input="step_input", additional_input={"input": "additional_test_input"} input="step_input", additional_input={"input": "additional_test_input"}
) )
step = await agent.create_and_execute_step(task.task_id, step_request) step = await agent.execute_step(task.task_id, step_request)
assert step.input == "step_input" assert step.input == "step_input"
assert step.additional_input == {"input": "additional_test_input"} assert step.additional_input == {"input": "additional_test_input"}
@pytest.mark.skip @pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_step(agent): async def test_get_step(agent: Agent):
task_request = TaskRequestBody( task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"} input="test_input", additional_input={"input": "additional_test_input"}
) )
@ -80,38 +86,52 @@ async def test_get_step(agent):
step_request = StepRequestBody( step_request = StepRequestBody(
input="step_input", additional_input={"input": "additional_test_input"} input="step_input", additional_input={"input": "additional_test_input"}
) )
step = await agent.create_and_execute_step(task.task_id, step_request) step = await agent.execute_step(task.task_id, step_request)
retrieved_step = await agent.get_step(task.task_id, step.step_id) retrieved_step = await agent.get_step(task.task_id, step.step_id)
assert retrieved_step.step_id == step.step_id assert retrieved_step.step_id == step.step_id
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_artifacts(agent): async def test_list_artifacts(agent: Agent):
artifacts = await agent.list_artifacts() tasks = await agent.list_tasks()
assert isinstance(artifacts, list) assert tasks.tasks, "No tasks in test.db"
artifacts = await agent.list_artifacts(tasks.tasks[0].task_id)
assert isinstance(artifacts.artifacts, list)
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_artifact(agent): async def test_create_artifact(agent: Agent, file_upload: UploadFile):
task_request = TaskRequestBody( task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"} input="test_input", additional_input={"input": "additional_test_input"}
) )
task = await agent.create_task(task_request) task = await agent.create_task(task_request)
artifact_request = ArtifactRequestBody(file=None, uri="test_uri") artifact = await agent.create_artifact(
artifact = await agent.create_artifact(task.task_id, artifact_request) task_id=task.task_id,
assert artifact.uri == "test_uri" file=file_upload,
relative_path=f"a_dir/{file_upload.filename}",
)
assert artifact.file_name == file_upload.filename
assert artifact.relative_path == f"a_dir/{file_upload.filename}"
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_artifact(agent): async def test_create_and_get_artifact(agent: Agent, file_upload: UploadFile):
task_request = TaskRequestBody( task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"} input="test_input", additional_input={"input": "additional_test_input"}
) )
task = await agent.create_task(task_request) task = await agent.create_task(task_request)
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
artifact = await agent.create_artifact(task.task_id, artifact_request) artifact = await agent.create_artifact(
task_id=task.task_id,
file=file_upload,
relative_path=f"b_dir/{file_upload.filename}",
)
await file_upload.seek(0)
file_upload_content = await file_upload.read()
retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id) retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
assert retrieved_artifact.artifact_id == artifact.artifact_id retrieved_artifact_content = bytearray()
async for b in retrieved_artifact.body_iterator:
retrieved_artifact_content.extend(b) # type: ignore
assert retrieved_artifact_content == file_upload_content
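The new `file_upload` fixture uses pytest's yield-fixture form, so the file handle is opened for the test and closed during teardown regardless of the test outcome. A minimal sketch of the same pattern with a temporary file (names assumed, `tmp_path` is pytest's built-in fixture):

```python
from pathlib import Path

import pytest


@pytest.fixture
def sample_file(tmp_path: Path):
    path = tmp_path / "sample.txt"
    path.write_text("hello")
    handle = path.open("rb")
    yield handle  # the test body runs here
    handle.close()  # teardown runs afterwards, pass or fail


def test_sample_file(sample_file):
    assert sample_file.read() == b"hello"
```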

View File

@ -5,22 +5,21 @@ import inspect
import logging import logging
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Awaitable,
Callable, Callable,
Generic,
Iterator, Iterator,
Optional, Optional,
ParamSpec, ParamSpec,
TypeVar, TypeVar,
cast,
overload, overload,
) )
from colorama import Fore from colorama import Fore
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
if TYPE_CHECKING:
from forge.models.action import ActionProposal, ActionResult
from forge.agent import protocols from forge.agent import protocols
from forge.agent.components import ( from forge.agent.components import (
AgentComponent, AgentComponent,
@ -29,15 +28,10 @@ from forge.agent.components import (
) )
from forge.config.ai_directives import AIDirectives from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile from forge.config.ai_profile import AIProfile
from forge.config.config import ConfigBuilder
from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName
from forge.llm.providers.schema import ChatModelInfo from forge.llm.providers.schema import ChatModelInfo
from forge.models.config import ( from forge.models.action import ActionResult, AnyProposal
Configurable, from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable
SystemConfiguration,
SystemSettings,
UserConfigurable,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -133,17 +127,7 @@ class AgentMeta(ABCMeta):
return instance return instance
class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
C = TypeVar("C", bound=AgentComponent)
default_settings = BaseAgentSettings(
name="BaseAgent",
description=__doc__ if __doc__ else "",
)
def __init__( def __init__(
self, self,
settings: BaseAgentSettings, settings: BaseAgentSettings,
@ -173,13 +157,13 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
return self.config.send_token_limit or self.llm.max_tokens * 3 // 4 return self.config.send_token_limit or self.llm.max_tokens * 3 // 4
@abstractmethod @abstractmethod
async def propose_action(self) -> ActionProposal: async def propose_action(self) -> AnyProposal:
... ...
@abstractmethod @abstractmethod
async def execute( async def execute(
self, self,
proposal: ActionProposal, proposal: AnyProposal,
user_feedback: str = "", user_feedback: str = "",
) -> ActionResult: ) -> ActionResult:
... ...
@ -187,7 +171,7 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
@abstractmethod @abstractmethod
async def do_not_execute( async def do_not_execute(
self, self,
denied_proposal: ActionProposal, denied_proposal: AnyProposal,
user_feedback: str, user_feedback: str,
) -> ActionResult: ) -> ActionResult:
... ...
@ -203,13 +187,16 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
@overload @overload
async def run_pipeline( async def run_pipeline(
self, protocol_method: Callable[P, None], *args, retry_limit: int = 3 self,
protocol_method: Callable[P, None | Awaitable[None]],
*args,
retry_limit: int = 3,
) -> list[None]: ) -> list[None]:
... ...
async def run_pipeline( async def run_pipeline(
self, self,
protocol_method: Callable[P, Iterator[T] | None], protocol_method: Callable[P, Iterator[T] | None | Awaitable[None]],
*args, *args,
retry_limit: int = 3, retry_limit: int = 3,
) -> list[T] | list[None]: ) -> list[T] | list[None]:
@ -240,7 +227,10 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
) )
continue continue
method = getattr(component, method_name, None) method = cast(
Callable[..., Iterator[T] | None | Awaitable[None]] | None,
getattr(component, method_name, None),
)
if not callable(method): if not callable(method):
continue continue
@ -248,10 +238,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
while component_attempts < retry_limit: while component_attempts < retry_limit:
try: try:
component_args = self._selective_copy(args) component_args = self._selective_copy(args)
if inspect.iscoroutinefunction(method): result = method(*component_args)
result = await method(*component_args) if inspect.isawaitable(result):
else: result = await result
result = method(*component_args)
if result is not None: if result is not None:
method_result.extend(result) method_result.extend(result)
args = component_args args = component_args
@ -269,9 +258,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
break break
# Successful pipeline execution # Successful pipeline execution
break break
except EndpointPipelineError: except EndpointPipelineError as e:
self._trace.append( self._trace.append(
f"{Fore.LIGHTRED_EX}{component.__class__.__name__}: " f"{Fore.LIGHTRED_EX}{e.triggerer.__class__.__name__}: "
f"EndpointPipelineError{Fore.RESET}" f"EndpointPipelineError{Fore.RESET}"
) )
# Restart from the beginning on EndpointPipelineError # Restart from the beginning on EndpointPipelineError
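The reworked dispatch above calls the component method first and only awaits the result when it is awaitable, which uniformly covers plain synchronous methods, `async def` methods, and synchronous callables that return an awaitable. A standalone sketch of that pattern, with assumed handler names:

```python
import asyncio
import inspect
from typing import Any, Awaitable, Callable, Iterator


async def run_handler(
    handler: Callable[..., Iterator[Any] | None | Awaitable[None]],
    *args: Any,
) -> list[Any]:
    result = handler(*args)
    if inspect.isawaitable(result):
        result = await result
    return list(result) if result is not None else []


def sync_handler(x: int) -> Iterator[int]:
    yield x * 2


async def async_handler(x: int) -> None:
    await asyncio.sleep(0)


async def main() -> None:
    assert await run_handler(sync_handler, 21) == [42]
    assert await run_handler(async_handler, 21) == []


asyncio.run(main())
```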

View File

@ -36,8 +36,9 @@ class AgentComponent(ABC):
class ComponentEndpointError(Exception): class ComponentEndpointError(Exception):
"""Error of a single protocol method on a component.""" """Error of a single protocol method on a component."""
def __init__(self, message: str = ""): def __init__(self, message: str, component: AgentComponent):
self.message = message self.message = message
self.triggerer = component
super().__init__(message) super().__init__(message)
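Carrying the offending component on the exception as `triggerer` lets the caller report which component failed without extra bookkeeping, which is what the pipeline's trace message above now relies on. A self-contained sketch of the idea, using stand-in classes that mirror the updated constructor signature:

```python
class StandInComponent:
    """Stand-in for an AgentComponent subclass (illustrative only)."""


class StandInEndpointError(Exception):
    """Mirrors the updated ComponentEndpointError signature."""

    def __init__(self, message: str, component: StandInComponent):
        self.message = message
        self.triggerer = component
        super().__init__(message)


component = StandInComponent()
try:
    raise StandInEndpointError("endpoint failed", component)
except StandInEndpointError as e:
    # The handler can name the failing component directly from the exception.
    assert e.triggerer is component
    print(f"{e.triggerer.__class__.__name__}: {e.message}")
```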

View File

@ -1,14 +1,13 @@
from abc import abstractmethod from abc import abstractmethod
from typing import TYPE_CHECKING, Iterator from typing import TYPE_CHECKING, Awaitable, Generic, Iterator
from forge.models.action import ActionResult, AnyProposal
from .components import AgentComponent from .components import AgentComponent
if TYPE_CHECKING: if TYPE_CHECKING:
from forge.command.command import Command from forge.command.command import Command
from forge.llm.providers import ChatMessage from forge.llm.providers import ChatMessage
from forge.models.action import ActionResult
from .base import ActionProposal
class DirectiveProvider(AgentComponent): class DirectiveProvider(AgentComponent):
@ -34,19 +33,19 @@ class MessageProvider(AgentComponent):
... ...
class AfterParse(AgentComponent): class AfterParse(AgentComponent, Generic[AnyProposal]):
@abstractmethod @abstractmethod
def after_parse(self, result: "ActionProposal") -> None: def after_parse(self, result: AnyProposal) -> None | Awaitable[None]:
... ...
class ExecutionFailure(AgentComponent): class ExecutionFailure(AgentComponent):
@abstractmethod @abstractmethod
def execution_failure(self, error: Exception) -> None: def execution_failure(self, error: Exception) -> None | Awaitable[None]:
... ...
class AfterExecute(AgentComponent): class AfterExecute(AgentComponent):
@abstractmethod @abstractmethod
def after_execute(self, result: "ActionResult") -> None: def after_execute(self, result: "ActionResult") -> None | Awaitable[None]:
... ...
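Widening the protocol return types to `None | Awaitable[None]` lets a component implement these hooks either synchronously or as `async def`; the pipeline's awaitable check decides at runtime. A sketch of both implementation styles against a simplified stand-in protocol (names assumed):

```python
import asyncio
import inspect
from abc import ABC, abstractmethod
from typing import Awaitable


class AfterExecuteProto(ABC):
    """Simplified stand-in for the AfterExecute protocol above."""

    @abstractmethod
    def after_execute(self, result: str) -> None | Awaitable[None]:
        ...


class SyncLogger(AfterExecuteProto):
    def after_execute(self, result: str) -> None:
        print(f"sync: {result}")


class AsyncLogger(AfterExecuteProto):
    async def after_execute(self, result: str) -> None:
        await asyncio.sleep(0)
        print(f"async: {result}")


async def notify(component: AfterExecuteProto, result: str) -> None:
    outcome = component.after_execute(result)
    if inspect.isawaitable(outcome):
        await outcome


asyncio.run(notify(SyncLogger(), "ok"))
asyncio.run(notify(AsyncLogger(), "ok"))
```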

View File

@ -1,39 +1,16 @@
""" """
Routes for the Agent Service. Routes for the Agent Service.
This module defines the API routes for the Agent service. While there are multiple endpoints provided by the service, This module defines the API routes for the Agent service.
the ones that require special attention due to their complexity are:
1. `execute_agent_task_step`: Developers and contributors should be especially careful when making modifications
This route is significant because this is where the agent actually performs the work. The function handles to these routes to ensure consistency and correctness in the system's behavior.
executing the next step for a task based on its current state, and it requires careful implementation to ensure
all scenarios (like the presence or absence of steps or a step marked as `last_step`) are handled correctly.
2. `upload_agent_task_artifacts`:
This route allows for the upload of artifacts, supporting various URI types (e.g., s3, gcs, ftp, http).
The support for different URI types makes it a bit more complex, and it's important to ensure that all
supported URI types are correctly managed. NOTE: The AutoGPT team will eventually handle the most common
uri types for you.
3. `create_agent_task`:
While this is a simpler route, it plays a crucial role in the workflow, as it's responsible for the creation
of a new task.
Developers and contributors should be especially careful when making modifications to these routes to ensure
consistency and correctness in the system's behavior.
""" """
import json
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, Query, Request, Response, UploadFile from fastapi import APIRouter, HTTPException, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import StreamingResponse
from forge.utils.exceptions import (
NotFoundError,
get_detailed_traceback,
get_exception_message,
)
from .models import ( from .models import (
Artifact, Artifact,
@ -46,6 +23,9 @@ from .models import (
TaskStepsListResponse, TaskStepsListResponse,
) )
if TYPE_CHECKING:
from forge.agent.agent import Agent
base_router = APIRouter() base_router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,10 +53,10 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
Args: Args:
request (Request): FastAPI request object. request (Request): FastAPI request object.
task (TaskRequestBody): The task request containing input and additional input data. task (TaskRequestBody): The task request containing input data.
Returns: Returns:
Task: A new task with task_id, input, additional_input, and empty lists for artifacts and steps. Task: A new task with task_id, input, and additional_input set.
Example: Example:
Request (TaskRequestBody defined in schema.py): Request (TaskRequestBody defined in schema.py):
@ -93,46 +73,32 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
"artifacts": [], "artifacts": [],
} }
""" """
agent = request["agent"] agent: "Agent" = request["agent"]
try: try:
task_request = await agent.create_task(task_request) task = await agent.create_task(task_request)
return Response( return task
content=task_request.json(),
status_code=200,
media_type="application/json",
)
except Exception: except Exception:
logger.exception(f"Error whilst trying to create a task: {task_request}") logger.exception(f"Error whilst trying to create a task: {task_request}")
return Response( raise
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
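
Note on the pattern above: the rewritten route returns the Pydantic model directly and re-raises unexpected errors instead of hand-building JSON Response objects; translating NotFoundError and friends into 404/500 responses is presumably left to app-level exception handlers. A hedged sketch of such handlers; the handler names and wiring are illustrative, not necessarily how the forge app registers them:

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from forge.utils.exceptions import NotFoundError

app = FastAPI()

@app.exception_handler(NotFoundError)
async def not_found_handler(request: Request, exc: NotFoundError) -> JSONResponse:
    # Any route that raises NotFoundError yields a uniform 404 payload.
    return JSONResponse(status_code=404, content={"error": str(exc)})

@app.exception_handler(Exception)
async def internal_error_handler(request: Request, exc: Exception) -> JSONResponse:
    # Unexpected errors become a generic 500 without leaking tracebacks to clients.
    return JSONResponse(status_code=500, content={"error": "Internal server error"})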
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse) @base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
async def list_agent_tasks( async def list_agent_tasks(
request: Request, request: Request,
page: Optional[int] = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1), page_size: int = Query(10, ge=1),
) -> TaskListResponse: ) -> TaskListResponse:
""" """
Retrieves a paginated list of all tasks. Retrieves a paginated list of all tasks.
Args: Args:
request (Request): FastAPI request object. request (Request): FastAPI request object.
page (int, optional): The page number for pagination. Defaults to 1. page (int, optional): Page number for pagination. Default: 1
page_size (int, optional): The number of tasks per page for pagination. Defaults to 10. page_size (int, optional): Number of tasks per page for pagination. Default: 10
Returns: Returns:
TaskListResponse: A response object containing a list of tasks and pagination details. TaskListResponse: A list of tasks, and pagination details.
Example: Example:
Request: Request:
@ -158,34 +124,13 @@ async def list_agent_tasks(
} }
} }
""" """
agent = request["agent"] agent: "Agent" = request["agent"]
try: try:
tasks = await agent.list_tasks(page, page_size) tasks = await agent.list_tasks(page, page_size)
return Response( return tasks
content=tasks.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps({"error": "Tasks not found"}),
status_code=404,
media_type="application/json",
)
except Exception: except Exception:
logger.exception("Error whilst trying to list tasks") logger.exception("Error whilst trying to list tasks")
return Response( raise
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task) @base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
@ -239,36 +184,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
} }
] ]
} }
""" """ # noqa: E501
agent = request["agent"] agent: "Agent" = request["agent"]
try: try:
task = await agent.get_task(task_id) task = await agent.get_task(task_id)
return task
return Response(
content=task.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception: except Exception:
logger.exception(f"Error whilst trying to get task: {task_id}") logger.exception(f"Error whilst trying to get task: {task_id}")
return Response( raise
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
@base_router.get( @base_router.get(
@ -279,8 +202,8 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
async def list_agent_task_steps( async def list_agent_task_steps(
request: Request, request: Request,
task_id: str, task_id: str,
page: Optional[int] = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"), page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskStepsListResponse: ) -> TaskStepsListResponse:
""" """
Retrieves a paginated list of steps associated with a specific task. Retrieves a paginated list of steps associated with a specific task.
@ -289,10 +212,10 @@ async def list_agent_task_steps(
request (Request): FastAPI request object. request (Request): FastAPI request object.
task_id (str): The ID of the task. task_id (str): The ID of the task.
page (int, optional): The page number for pagination. Defaults to 1. page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): The number of steps per page for pagination. Defaults to 10. page_size (int, optional): Number of steps per page for pagination. Default: 10.
Returns: Returns:
TaskStepsListResponse: A response object containing a list of steps and pagination details. TaskStepsListResponse: A list of steps, and pagination details.
Example: Example:
Request: Request:
@ -315,54 +238,40 @@ async def list_agent_task_steps(
"pageSize": 10 "pageSize": 10
} }
} }
""" """ # noqa: E501
agent = request["agent"] agent: "Agent" = request["agent"]
try: try:
steps = await agent.list_steps(task_id, page, page_size) steps = await agent.list_steps(task_id, page, page_size)
return Response( return steps
content=steps.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception("Error whilst trying to list steps")
return Response(
content=json.dumps({"error": "Steps not found"}),
status_code=404,
media_type="application/json",
)
except Exception: except Exception:
logger.exception("Error whilst trying to list steps") logger.exception("Error whilst trying to list steps")
return Response( raise
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step) @base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
async def execute_agent_task_step( async def execute_agent_task_step(
request: Request, task_id: str, step: Optional[StepRequestBody] = None request: Request, task_id: str, step_request: Optional[StepRequestBody] = None
) -> Step: ) -> Step:
""" """
Executes the next step for a specified task based on the current task status and returns the Executes the next step for a specified task based on the current task status and
executed step with additional feedback fields. returns the executed step with additional feedback fields.
Depending on the current state of the task, the following scenarios are supported: This route is significant because this is where the agent actually performs work.
The function handles executing the next step for a task based on its current state,
and it requires careful implementation to ensure all scenarios (like the presence
or absence of steps or a step marked as `last_step`) are handled correctly.
Depending on the current state of the task, the following scenarios are possible:
1. No steps exist for the task. 1. No steps exist for the task.
2. There is at least one step already for the task, and the task does not have a completed step marked as `last_step`. 2. There is at least one step already for the task, and the task does not have a
completed step marked as `last_step`.
3. There is a completed step marked as `last_step` already on the task. 3. There is a completed step marked as `last_step` already on the task.
In each of these scenarios, a step object will be returned with two additional fields: `output` and `additional_output`. In each of these scenarios, a step object will be returned with two additional
fields: `output` and `additional_output`.
- `output`: Provides the primary response or feedback to the user. - `output`: Provides the primary response or feedback to the user.
- `additional_output`: Supplementary information or data. Its specific content is not strictly defined and can vary based on the step or agent's implementation. - `additional_output`: Supplementary information or data. Its specific content is
not strictly defined and can vary based on the step or agent's implementation.
Args: Args:
request (Request): FastAPI request object. request (Request): FastAPI request object.
@ -389,39 +298,17 @@ async def execute_agent_task_step(
... ...
} }
""" """
agent = request["agent"] agent: "Agent" = request["agent"]
try: try:
# An empty step request represents a yes to continue command # An empty step request represents a yes to continue command
if not step: if not step_request:
step = StepRequestBody(input="y") step_request = StepRequestBody(input="y")
step = await agent.execute_step(task_id, step) step = await agent.execute_step(task_id, step_request)
return step
return Response(
content=step.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps({"error": f"Task not found {task_id}"}),
status_code=404,
media_type="application/json",
)
except Exception: except Exception:
logger.exception(f"Error whilst trying to execute a task step: {task_id}") logger.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response( raise
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
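
As the updated docstring notes, an empty step request means "yes, continue with the next step". A hedged client-side sketch using httpx; the base URL and mount point are assumptions for illustration and depend on how the agent protocol server is run:

import httpx

BASE_URL = "http://localhost:8000/ap/v1"  # assumed address of the agent protocol server

with httpx.Client(base_url=BASE_URL) as client:
    task = client.post("/agent/tasks", json={"input": "Say hello"}).json()

    # Posting no body at all is the "continue" signal: the route substitutes
    # StepRequestBody(input="y") itself.
    step = client.post(f"/agent/tasks/{task['task_id']}/steps").json()
    print(step["output"], step["is_last"])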
@base_router.get( @base_router.get(
@ -450,31 +337,13 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
... ...
} }
""" """
agent = request["agent"] agent: "Agent" = request["agent"]
try: try:
step = await agent.get_step(task_id, step_id) step = await agent.get_step(task_id, step_id)
return step
return Response(content=step.json(), status_code=200)
except NotFoundError:
logger.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps({"error": "Step not found"}),
status_code=404,
media_type="application/json",
)
except Exception: except Exception:
logger.exception(f"Error whilst trying to get step: {step_id}") logger.exception(f"Error whilst trying to get step: {step_id}")
return Response( raise
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
@base_router.get( @base_router.get(
@ -485,8 +354,8 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
async def list_agent_task_artifacts( async def list_agent_task_artifacts(
request: Request, request: Request,
task_id: str, task_id: str,
page: Optional[int] = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"), page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskArtifactsListResponse: ) -> TaskArtifactsListResponse:
""" """
Retrieves a paginated list of artifacts associated with a specific task. Retrieves a paginated list of artifacts associated with a specific task.
@ -495,10 +364,10 @@ async def list_agent_task_artifacts(
request (Request): FastAPI request object. request (Request): FastAPI request object.
task_id (str): The ID of the task. task_id (str): The ID of the task.
page (int, optional): The page number for pagination. Defaults to 1. page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): The number of items per page for pagination. Defaults to 10. page_size (int, optional): Number of items per page for pagination. Default: 10.
Returns: Returns:
TaskArtifactsListResponse: A response object containing a list of artifacts and pagination details. TaskArtifactsListResponse: A list of artifacts, and pagination details.
Example: Example:
Request: Request:
@ -518,52 +387,33 @@ async def list_agent_task_artifacts(
"pageSize": 10 "pageSize": 10
} }
} }
""" """ # noqa: E501
agent = request["agent"] agent: "Agent" = request["agent"]
try: try:
artifacts: TaskArtifactsListResponse = await agent.list_artifacts( artifacts = await agent.list_artifacts(task_id, page, page_size)
task_id, page, page_size
)
return artifacts return artifacts
except NotFoundError:
logger.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps({"error": "Artifacts not found for task_id"}),
status_code=404,
media_type="application/json",
)
except Exception: except Exception:
logger.exception("Error whilst trying to list artifacts") logger.exception("Error whilst trying to list artifacts")
return Response( raise
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
@base_router.post( @base_router.post(
"/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact "/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact
) )
async def upload_agent_task_artifacts( async def upload_agent_task_artifacts(
request: Request, task_id: str, file: UploadFile, relative_path: Optional[str] = "" request: Request, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact: ) -> Artifact:
""" """
This endpoint is used to upload an artifact associated with a specific task. The artifact is provided as a file. This endpoint is used to upload an artifact (file) associated with a specific task.
Args: Args:
request (Request): The FastAPI request object. request (Request): The FastAPI request object.
task_id (str): The unique identifier of the task for which the artifact is being uploaded. task_id (str): The ID of the task for which the artifact is being uploaded.
file (UploadFile): The file being uploaded as an artifact. file (UploadFile): The file being uploaded as an artifact.
relative_path (str): The relative path for the file. This is a query parameter. relative_path (str): The relative path for the file. This is a query parameter.
Returns: Returns:
Artifact: An object containing metadata of the uploaded artifact, including its unique identifier. Artifact: Metadata object for the uploaded artifact, including its ID and path.
Example: Example:
Request: Request:
@ -579,35 +429,17 @@ async def upload_agent_task_artifacts(
"relative_path": "/my_folder/my_other_folder/", "relative_path": "/my_folder/my_other_folder/",
"file_name": "main.py" "file_name": "main.py"
} }
""" """ # noqa: E501
agent = request["agent"] agent: "Agent" = request["agent"]
if file is None: if file is None:
return Response( raise HTTPException(status_code=400, detail="File must be specified")
content=json.dumps({"error": "File must be specified"}),
status_code=404,
media_type="application/json",
)
try: try:
artifact = await agent.create_artifact(task_id, file, relative_path) artifact = await agent.create_artifact(task_id, file, relative_path)
return Response( return artifact
content=artifact.json(),
status_code=200,
media_type="application/json",
)
except Exception: except Exception:
logger.exception(f"Error whilst trying to upload artifact: {task_id}") logger.exception(f"Error whilst trying to upload artifact: {task_id}")
return Response( raise
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
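
For the upload route above, the file arrives as multipart form data while relative_path is a query parameter (per the docstring). A hedged client sketch; the base URL, task ID, and file are placeholders:

import httpx

BASE_URL = "http://localhost:8000/ap/v1"  # assumed mount point, as above

with httpx.Client(base_url=BASE_URL) as client:
    task_id = "50da533e-3904-4401-8a07-c49adf88b5eb"  # placeholder task ID
    with open("main.py", "rb") as f:
        response = client.post(
            f"/agent/tasks/{task_id}/artifacts",
            params={"relative_path": "my_folder/my_other_folder/"},
            files={"file": ("main.py", f, "text/x-python")},
        )
    print(response.json()["artifact_id"])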
@base_router.get( @base_router.get(
@ -617,7 +449,7 @@ async def upload_agent_task_artifacts(
) )
async def download_agent_task_artifact( async def download_agent_task_artifact(
request: Request, task_id: str, artifact_id: str request: Request, task_id: str, artifact_id: str
) -> FileResponse: ) -> StreamingResponse:
""" """
Downloads an artifact associated with a specific task. Downloads an artifact associated with a specific task.
@ -636,32 +468,9 @@ async def download_agent_task_artifact(
Response: Response:
<file_content_of_artifact> <file_content_of_artifact>
""" """
agent = request["agent"] agent: "Agent" = request["agent"]
try: try:
return await agent.get_artifact(task_id, artifact_id) return await agent.get_artifact(task_id, artifact_id)
except NotFoundError:
logger.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{
"error": f"Artifact not found "
"- task_id: {task_id}, artifact_id: {artifact_id}"
}
),
status_code=404,
media_type="application/json",
)
except Exception: except Exception:
logger.exception(f"Error whilst trying to download artifact: {task_id}") logger.exception(f"Error whilst trying to download artifact: {task_id}")
return Response( raise
content=json.dumps(
{
"error": f"Internal server error "
"- task_id: {task_id}, artifact_id: {artifact_id}",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)


@ -1 +1,3 @@
from .db import AgentDB from .db import AgentDB
__all__ = ["AgentDB"]


@ -4,23 +4,22 @@ It uses SQLite as the database and file store backend.
IT IS NOT ADVISED TO USE THIS IN PRODUCTION! IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
""" """
import datetime
import logging import logging
import math import math
import uuid import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple from typing import Any, Dict, List, Literal, Optional, Tuple
from sqlalchemy import ( from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, create_engine
JSON,
Boolean,
Column,
DateTime,
ForeignKey,
String,
create_engine,
)
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
joinedload,
mapped_column,
relationship,
sessionmaker,
)
from forge.utils.exceptions import NotFoundError from forge.utils.exceptions import NotFoundError
@ -32,18 +31,20 @@ logger = logging.getLogger(__name__)
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass type_annotation_map = {
dict[str, Any]: JSON,
}
class TaskModel(Base): class TaskModel(Base):
__tablename__ = "tasks" __tablename__ = "tasks"
task_id = Column(String, primary_key=True, index=True) task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
input = Column(String) input: Mapped[str]
additional_input = Column(JSON) additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
created_at = Column(DateTime, default=datetime.datetime.utcnow) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
modified_at = Column( modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
) )
artifacts = relationship("ArtifactModel", back_populates="task") artifacts = relationship("ArtifactModel", back_populates="task")
@ -52,35 +53,35 @@ class TaskModel(Base):
class StepModel(Base): class StepModel(Base):
__tablename__ = "steps" __tablename__ = "steps"
step_id = Column(String, primary_key=True, index=True) step_id: Mapped[str] = mapped_column(primary_key=True, index=True)
task_id = Column(String, ForeignKey("tasks.task_id")) task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
name = Column(String) name: Mapped[str]
input = Column(String) input: Mapped[str]
status = Column(String) status: Mapped[str]
output = Column(String) output: Mapped[Optional[str]]
is_last = Column(Boolean, default=False) is_last: Mapped[bool] = mapped_column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
modified_at = Column( modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
) )
additional_input = Column(JSON) additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
additional_output = Column(JSON) additional_output: Mapped[Optional[dict[str, Any]]]
artifacts = relationship("ArtifactModel", back_populates="step") artifacts = relationship("ArtifactModel", back_populates="step")
class ArtifactModel(Base): class ArtifactModel(Base):
__tablename__ = "artifacts" __tablename__ = "artifacts"
artifact_id = Column(String, primary_key=True, index=True) artifact_id: Mapped[str] = mapped_column(primary_key=True, index=True)
task_id = Column(String, ForeignKey("tasks.task_id")) task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
step_id = Column(String, ForeignKey("steps.step_id")) step_id: Mapped[Optional[str]] = mapped_column(ForeignKey("steps.step_id"))
agent_created = Column(Boolean, default=False) agent_created: Mapped[bool] = mapped_column(default=False)
file_name = Column(String) file_name: Mapped[str]
relative_path = Column(String) relative_path: Mapped[str]
created_at = Column(DateTime, default=datetime.datetime.utcnow) created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
modified_at = Column( modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow default=datetime.utcnow, onupdate=datetime.utcnow
) )
step = relationship("StepModel", back_populates="artifacts") step = relationship("StepModel", back_populates="artifacts")
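
The models above move from Column(...) declarations to SQLAlchemy 2.0-style typed mappings: a plain Mapped[...] annotation becomes a column, Optional[...] makes it nullable, and the type_annotation_map on the base maps dict[str, Any] to the JSON column type. A condensed, self-contained sketch of the same pattern with an illustrative table:

from datetime import datetime
from typing import Any, Optional

from sqlalchemy import JSON, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    # Attributes annotated Mapped[dict[str, Any]] are stored as JSON columns.
    type_annotation_map = {dict[str, Any]: JSON}

class ExampleTask(Base):
    __tablename__ = "example_tasks"

    task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
    input: Mapped[str]                          # NOT NULL, type inferred from str
    output: Mapped[Optional[str]]               # nullable, no mapped_column() needed
    additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
    created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)  # emits CREATE TABLE with the inferred column types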
@ -150,6 +151,10 @@ class AgentDB:
Base.metadata.create_all(self.engine) Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine) self.Session = sessionmaker(bind=self.engine)
def close(self) -> None:
self.Session.close_all()
self.engine.dispose()
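
The new close() method mainly exists so tests can release the SQLite file before deleting it; open sessions and the engine's connection pool otherwise keep the file locked, which blocks os.remove on Windows. A minimal usage sketch (the database URL is a placeholder):

import os

from forge.agent_protocol.database import AgentDB

db = AgentDB("sqlite:///scratch_db.sqlite3")
try:
    ...  # exercise the database
finally:
    db.close()  # close sessions and dispose of the engine's connection pool
    os.remove("scratch_db.sqlite3")  # the file can now be removed safely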
async def create_task( async def create_task(
self, input: Optional[str], additional_input: Optional[dict] = {} self, input: Optional[str], additional_input: Optional[dict] = {}
) -> Task: ) -> Task:
@ -172,8 +177,6 @@ class AgentDB:
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating task: {e}") logger.error(f"SQLAlchemy error while creating task: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while creating task: {e}") logger.error(f"Unexpected error while creating task: {e}")
raise raise
@ -207,8 +210,6 @@ class AgentDB:
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating step: {e}") logger.error(f"SQLAlchemy error while creating step: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while creating step: {e}") logger.error(f"Unexpected error while creating step: {e}")
raise raise
@ -237,7 +238,7 @@ class AgentDB:
session.close() session.close()
if self.debug_enabled: if self.debug_enabled:
logger.debug( logger.debug(
f"Artifact already exists with relative_path: {relative_path}" f"Artifact {file_name} already exists at {relative_path}/"
) )
return convert_to_artifact(existing_artifact) return convert_to_artifact(existing_artifact)
@ -254,14 +255,12 @@ class AgentDB:
session.refresh(new_artifact) session.refresh(new_artifact)
if self.debug_enabled: if self.debug_enabled:
logger.debug( logger.debug(
f"Created new artifact with artifact_id: {new_artifact.artifact_id}" f"Created new artifact with ID: {new_artifact.artifact_id}"
) )
return convert_to_artifact(new_artifact) return convert_to_artifact(new_artifact)
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating step: {e}") logger.error(f"SQLAlchemy error while creating step: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while creating step: {e}") logger.error(f"Unexpected error while creating step: {e}")
raise raise
@ -285,8 +284,6 @@ class AgentDB:
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting task: {e}") logger.error(f"SQLAlchemy error while getting task: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while getting task: {e}") logger.error(f"Unexpected error while getting task: {e}")
raise raise
@ -312,8 +309,6 @@ class AgentDB:
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting step: {e}") logger.error(f"SQLAlchemy error while getting step: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while getting step: {e}") logger.error(f"Unexpected error while getting step: {e}")
raise raise
@ -337,8 +332,6 @@ class AgentDB:
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting artifact: {e}") logger.error(f"SQLAlchemy error while getting artifact: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while getting artifact: {e}") logger.error(f"Unexpected error while getting artifact: {e}")
raise raise
@ -375,14 +368,13 @@ class AgentDB:
return await self.get_step(task_id, step_id) return await self.get_step(task_id, step_id)
else: else:
logger.error( logger.error(
f"Step not found for update with task_id: {task_id} and step_id: {step_id}" "Can't update non-existent Step with "
f"task_id: {task_id} and step_id: {step_id}"
) )
raise NotFoundError("Step not found") raise NotFoundError("Step not found")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting step: {e}") logger.error(f"SQLAlchemy error while getting step: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while getting step: {e}") logger.error(f"Unexpected error while getting step: {e}")
raise raise
@ -441,8 +433,6 @@ class AgentDB:
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing tasks: {e}") logger.error(f"SQLAlchemy error while listing tasks: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while listing tasks: {e}") logger.error(f"Unexpected error while listing tasks: {e}")
raise raise
@ -475,8 +465,6 @@ class AgentDB:
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing steps: {e}") logger.error(f"SQLAlchemy error while listing steps: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while listing steps: {e}") logger.error(f"Unexpected error while listing steps: {e}")
raise raise
@ -509,8 +497,6 @@ class AgentDB:
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing artifacts: {e}") logger.error(f"SQLAlchemy error while listing artifacts: {e}")
raise raise
except NotFoundError as e:
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error while listing artifacts: {e}") logger.error(f"Unexpected error while listing artifacts: {e}")
raise raise


@ -22,14 +22,27 @@ from forge.agent_protocol.models import (
) )
from forge.utils.exceptions import NotFoundError as DataNotFoundError from forge.utils.exceptions import NotFoundError as DataNotFoundError
TEST_DB_FILENAME = "test_db.sqlite3"
TEST_DB_URL = f"sqlite:///{TEST_DB_FILENAME}"
@pytest.mark.asyncio
def test_table_creation():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
conn = sqlite3.connect("test_db.sqlite3") @pytest.fixture
cursor = conn.cursor() def agent_db():
db = AgentDB(TEST_DB_URL)
yield db
db.close()
os.remove(TEST_DB_FILENAME)
@pytest.fixture
def raw_db_connection(agent_db: AgentDB):
connection = sqlite3.connect(TEST_DB_FILENAME)
yield connection
connection.close()
def test_table_creation(raw_db_connection: sqlite3.Connection):
cursor = raw_db_connection.cursor()
# Test for tasks table existence # Test for tasks table existence
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'") cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
@ -45,8 +58,6 @@ def test_table_creation():
) )
assert cursor.fetchone() is not None assert cursor.fetchone() is not None
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_task_schema(): async def test_task_schema():
@ -84,7 +95,10 @@ async def test_step_schema():
name="Write to file", name="Write to file",
input="Write the words you receive to the file 'output.txt'.", input="Write the words you receive to the file 'output.txt'.",
status=StepStatus.created, status=StepStatus.created,
output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>", output=(
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
),
artifacts=[ artifacts=[
Artifact( Artifact(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -101,13 +115,13 @@ async def test_step_schema():
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e" assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
assert step.name == "Write to file" assert step.name == "Write to file"
assert step.status == StepStatus.created assert step.status == StepStatus.created
assert ( assert step.output == (
step.output "I am going to use the write_to_file command and write Washington "
== "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>" "to a file called output.txt <write_to_file('output.txt', 'Washington')>"
) )
assert len(step.artifacts) == 1 assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last == False assert step.is_last is False
@pytest.mark.asyncio @pytest.mark.asyncio
@ -118,6 +132,7 @@ async def test_convert_to_task():
created_at=now, created_at=now,
modified_at=now, modified_at=now,
input="Write the words you receive to the file 'output.txt'.", input="Write the words you receive to the file 'output.txt'.",
additional_input={},
artifacts=[ artifacts=[
ArtifactModel( ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -147,6 +162,7 @@ async def test_convert_to_step():
name="Write to file", name="Write to file",
status="created", status="created",
input="Write the words you receive to the file 'output.txt'.", input="Write the words you receive to the file 'output.txt'.",
additional_input={},
artifacts=[ artifacts=[
ArtifactModel( ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -166,7 +182,7 @@ async def test_convert_to_step():
assert step.status == StepStatus.created assert step.status == StepStatus.created
assert len(step.artifacts) == 1 assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last == False assert step.is_last is False
@pytest.mark.asyncio @pytest.mark.asyncio
@ -183,91 +199,67 @@ async def test_convert_to_artifact():
artifact = convert_to_artifact(artifact_model) artifact = convert_to_artifact(artifact_model)
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert artifact.relative_path == "file:///path/to/main.py" assert artifact.relative_path == "file:///path/to/main.py"
assert artifact.agent_created == True assert artifact.agent_created is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_task(): async def test_create_task(agent_db: AgentDB):
# Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
# TODO: Fix this!
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
task = await agent_db.create_task("task_input") task = await agent_db.create_task("task_input")
assert task.input == "task_input" assert task.input == "task_input"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_and_get_task(): async def test_create_and_get_task(agent_db: AgentDB):
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
task = await agent_db.create_task("test_input") task = await agent_db.create_task("test_input")
fetched_task = await agent_db.get_task(task.task_id) fetched_task = await agent_db.get_task(task.task_id)
assert fetched_task.input == "test_input" assert fetched_task.input == "test_input"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_task_not_found(): async def test_get_task_not_found(agent_db: AgentDB):
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
with pytest.raises(DataNotFoundError): with pytest.raises(DataNotFoundError):
await agent_db.get_task(9999) await agent_db.get_task("9999")
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_and_get_step(): async def test_create_and_get_step(agent_db: AgentDB):
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
task = await agent_db.create_task("task_input") task = await agent_db.create_task("task_input")
step_input = StepInput(type="python/code") step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input) request = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await agent_db.create_step(task.task_id, request) step = await agent_db.create_step(task.task_id, request)
step = await agent_db.get_step(task.task_id, step.step_id) step = await agent_db.get_step(task.task_id, step.step_id)
assert step.input == "test_input debug" assert step.input == "test_input debug"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_updating_step(): async def test_updating_step(agent_db: AgentDB):
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
created_task = await agent_db.create_task("task_input") created_task = await agent_db.create_task("task_input")
step_input = StepInput(type="python/code") step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input) request = StepRequestBody(input="test_input debug", additional_input=step_input)
created_step = await agent_db.create_step(created_task.task_id, request) created_step = await agent_db.create_step(created_task.task_id, request)
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed") await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
step = await agent_db.get_step(created_task.task_id, created_step.step_id) step = await agent_db.get_step(created_task.task_id, created_step.step_id)
assert step.status.value == "completed" assert step.status.value == "completed"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_step_not_found(): async def test_get_step_not_found(agent_db: AgentDB):
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
with pytest.raises(DataNotFoundError): with pytest.raises(DataNotFoundError):
await agent_db.get_step(9999, 9999) await agent_db.get_step("9999", "9999")
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_artifact(): async def test_get_artifact(agent_db: AgentDB):
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
# Given: A task and its corresponding artifact # Given: A task and its corresponding artifact
task = await db.create_task("test_input debug") task = await agent_db.create_task("test_input debug")
step_input = StepInput(type="python/code") step_input = {"type": "python/code"}
requst = StepRequestBody(input="test_input debug", additional_input=step_input) requst = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await db.create_step(task.task_id, requst) step = await agent_db.create_step(task.task_id, requst)
# Create an artifact # Create an artifact
artifact = await db.create_artifact( artifact = await agent_db.create_artifact(
task_id=task.task_id, task_id=task.task_id,
file_name="test_get_artifact_sample_file.txt", file_name="test_get_artifact_sample_file.txt",
relative_path="file:///path/to/test_get_artifact_sample_file.txt", relative_path="file:///path/to/test_get_artifact_sample_file.txt",
@ -276,7 +268,7 @@ async def test_get_artifact():
) )
# When: The artifact is fetched by its ID # When: The artifact is fetched by its ID
fetched_artifact = await db.get_artifact(artifact.artifact_id) fetched_artifact = await agent_db.get_artifact(artifact.artifact_id)
# Then: The fetched artifact matches the original # Then: The fetched artifact matches the original
assert fetched_artifact.artifact_id == artifact.artifact_id assert fetched_artifact.artifact_id == artifact.artifact_id
@ -285,47 +277,37 @@ async def test_get_artifact():
== "file:///path/to/test_get_artifact_sample_file.txt" == "file:///path/to/test_get_artifact_sample_file.txt"
) )
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_tasks(): async def test_list_tasks(agent_db: AgentDB):
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
# Given: Multiple tasks in the database # Given: Multiple tasks in the database
task1 = await db.create_task("test_input_1") task1 = await agent_db.create_task("test_input_1")
task2 = await db.create_task("test_input_2") task2 = await agent_db.create_task("test_input_2")
# When: All tasks are fetched # When: All tasks are fetched
fetched_tasks, pagination = await db.list_tasks() fetched_tasks, pagination = await agent_db.list_tasks()
# Then: The fetched tasks list includes the created tasks # Then: The fetched tasks list includes the created tasks
task_ids = [task.task_id for task in fetched_tasks] task_ids = [task.task_id for task in fetched_tasks]
assert task1.task_id in task_ids assert task1.task_id in task_ids
assert task2.task_id in task_ids assert task2.task_id in task_ids
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_steps(): async def test_list_steps(agent_db: AgentDB):
db_name = "sqlite:///test_db.sqlite3" step_input = {"type": "python/code"}
db = AgentDB(db_name) request = StepRequestBody(input="test_input debug", additional_input=step_input)
step_input = StepInput(type="python/code")
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
# Given: A task and multiple steps for that task # Given: A task and multiple steps for that task
task = await db.create_task("test_input") task = await agent_db.create_task("test_input")
step1 = await db.create_step(task.task_id, requst) step1 = await agent_db.create_step(task.task_id, request)
requst = StepRequestBody(input="step two", additional_input=step_input) request = StepRequestBody(input="step two")
step2 = await db.create_step(task.task_id, requst) step2 = await agent_db.create_step(task.task_id, request)
# When: All steps for the task are fetched # When: All steps for the task are fetched
fetched_steps, pagination = await db.list_steps(task.task_id) fetched_steps, pagination = await agent_db.list_steps(task.task_id)
# Then: The fetched steps list includes the created steps # Then: The fetched steps list includes the created steps
step_ids = [step.step_id for step in fetched_steps] step_ids = [step.step_id for step in fetched_steps]
assert step1.step_id in step_ids assert step1.step_id in step_ids
assert step2.step_id in step_ids assert step2.step_id in step_ids
os.remove(db_name.split("///")[1])


@ -1,4 +1,4 @@
from .artifact import Artifact, ArtifactUpload from .artifact import Artifact
from .pagination import Pagination from .pagination import Pagination
from .task import ( from .task import (
Step, Step,
@ -10,3 +10,16 @@ from .task import (
TaskRequestBody, TaskRequestBody,
TaskStepsListResponse, TaskStepsListResponse,
) )
__all__ = [
"Artifact",
"Pagination",
"Step",
"StepRequestBody",
"StepStatus",
"Task",
"TaskArtifactsListResponse",
"TaskListResponse",
"TaskRequestBody",
"TaskStepsListResponse",
]


@ -3,15 +3,6 @@ from datetime import datetime
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class ArtifactUpload(BaseModel):
file: str = Field(..., description="File to upload.", format="binary")
relative_path: str = Field(
...,
description="Relative path of the artifact in the agent's workspace.",
example="python/code",
)
class Artifact(BaseModel): class Artifact(BaseModel):
created_at: datetime = Field( created_at: datetime = Field(
..., ...,


@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import List, Optional from typing import Any, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -17,7 +17,7 @@ class TaskRequestBody(BaseModel):
description="Input prompt for the task.", description="Input prompt for the task.",
example="Write the words you receive to the file 'output.txt'.", example="Write the words you receive to the file 'output.txt'.",
) )
additional_input: Optional[dict] = None additional_input: dict[str, Any] = Field(default_factory=dict)
class Task(TaskRequestBody): class Task(TaskRequestBody):
@ -38,8 +38,8 @@ class Task(TaskRequestBody):
description="The ID of the task.", description="The ID of the task.",
example="50da533e-3904-4401-8a07-c49adf88b5eb", example="50da533e-3904-4401-8a07-c49adf88b5eb",
) )
artifacts: Optional[List[Artifact]] = Field( artifacts: list[Artifact] = Field(
[], default_factory=list,
description="A list of artifacts that the task has produced.", description="A list of artifacts that the task has produced.",
example=[ example=[
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e", "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
@ -50,14 +50,12 @@ class Task(TaskRequestBody):
class StepRequestBody(BaseModel): class StepRequestBody(BaseModel):
name: Optional[str] = Field( name: Optional[str] = Field(
None, description="The name of the task step.", example="Write to file" default=None, description="The name of the task step.", example="Write to file"
) )
input: Optional[str] = Field( input: str = Field(
None, ..., description="Input prompt for the step.", example="Washington"
description="Input prompt for the step.",
example="Washington",
) )
additional_input: Optional[dict] = None additional_input: dict[str, Any] = Field(default_factory=dict)
class StepStatus(Enum): class StepStatus(Enum):
@ -90,19 +88,23 @@ class Step(StepRequestBody):
example="6bb1801a-fd80-45e8-899a-4dd723cc602e", example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
) )
name: Optional[str] = Field( name: Optional[str] = Field(
None, description="The name of the task step.", example="Write to file" default=None, description="The name of the task step.", example="Write to file"
) )
status: StepStatus = Field( status: StepStatus = Field(
..., description="The status of the task step.", example="created" ..., description="The status of the task step.", example="created"
) )
output: Optional[str] = Field( output: Optional[str] = Field(
None, default=None,
description="Output of the task step.", description="Output of the task step.",
example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')", example=(
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')"
),
) )
additional_output: Optional[dict] = None additional_output: Optional[dict[str, Any]] = None
artifacts: Optional[List[Artifact]] = Field( artifacts: list[Artifact] = Field(
[], description="A list of artifacts that the step has produced." default_factory=list,
description="A list of artifacts that the step has produced.",
) )
is_last: bool = Field( is_last: bool = Field(
..., description="Whether this is the last step in the task.", example=True ..., description="Whether this is the last step in the task.", example=True
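
The Optional[dict]/[] defaults above are replaced with Field(default_factory=...): the fields are no longer Optional, so callers get a real dict or list without None checks, and each instance gets a freshly constructed container. A small illustrative sketch in the pydantic v1 style used here; the model name is made up:

from typing import Any

from pydantic import BaseModel, Field

class StepRequestBodyExample(BaseModel):
    input: str
    additional_input: dict[str, Any] = Field(default_factory=dict)

a = StepRequestBodyExample(input="one")
b = StepRequestBodyExample(input="two")
a.additional_input["key"] = "value"
assert a.additional_input == {"key": "value"}
assert b.additional_input == {}  # each instance has its own dict, never None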


@ -1,3 +1,5 @@
from .command import Command, CommandOutput, CommandParameter from .command import Command
from .decorator import command from .decorator import command
from .parameter import CommandParameter from .parameter import CommandParameter
__all__ = ["Command", "CommandParameter", "command"]


@ -1,14 +1,16 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
from typing import Any, Callable, Generic, ParamSpec, TypeVar from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
from forge.agent.protocols import CommandProvider
from .parameter import CommandParameter from .parameter import CommandParameter
CommandOutput = Any
P = ParamSpec("P") P = ParamSpec("P")
CO = TypeVar("CO", bound=CommandOutput) CO = TypeVar("CO") # command output
_CP = TypeVar("_CP", bound=CommandProvider)
class Command(Generic[P, CO]): class Command(Generic[P, CO]):
@ -24,7 +26,7 @@ class Command(Generic[P, CO]):
self, self,
names: list[str], names: list[str],
description: str, description: str,
method: Callable[P, CO], method: Callable[Concatenate[_CP, P], CO],
parameters: list[CommandParameter], parameters: list[CommandParameter],
): ):
# Check if all parameters are provided # Check if all parameters are provided
@ -34,7 +36,9 @@ class Command(Generic[P, CO]):
) )
self.names = names self.names = names
self.description = description self.description = description
self.method = method # Method technically has a `self` parameter, but we can ignore that
# since Python passes it internally.
self.method = cast(Callable[P, CO], method)
self.parameters = parameters self.parameters = parameters
@property @property
@ -62,7 +66,8 @@ class Command(Generic[P, CO]):
def __str__(self) -> str: def __str__(self) -> str:
params = [ params = [
f"{param.name}: " f"{param.name}: "
+ ("%s" if param.spec.required else "Optional[%s]") % param.spec.type.value + ("%s" if param.spec.required else "Optional[%s]")
% (param.spec.type.value if param.spec.type else "Any")
for param in self.parameters for param in self.parameters
] ]
return ( return (
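
On the typing change above: the wrapped method is declared Callable[Concatenate[_CP, P], CO], i.e. a provider method whose first parameter is its own self, and the cast drops that parameter because it is filled in when the method is bound to an instance. A hedged, descriptor-based simplification of the same trick; this is not forge's actual Command wiring, just the Concatenate/cast idea in isolation:

from typing import Any, Callable, Concatenate, Generic, ParamSpec, TypeVar, cast

P = ParamSpec("P")
CO = TypeVar("CO")    # command output
_CP = TypeVar("_CP")  # command provider, i.e. the `self` of the wrapped method

class Command(Generic[P, CO]):
    def __init__(self, method: Callable[Concatenate[_CP, P], CO]) -> None:
        self._method = method

    def __get__(self, instance: Any, owner: Any = None) -> Callable[P, CO]:
        # Binding the wrapped function to the provider instance fills in `self`,
        # so what remains is exactly Callable[P, CO].
        return cast(Callable[P, CO], self._method.__get__(instance, owner))

class FileProvider:
    @Command  # wraps the *unbound* method, whose signature is (self, path)
    def read_file(self, path: str) -> str:
        return f"<contents of {path}>"

print(FileProvider().read_file("notes.txt"))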


@ -1,2 +1,4 @@
from .action_history import ActionHistoryComponent from .action_history import ActionHistoryComponent
from .model import Episode, EpisodicActionHistory from .model import Episode, EpisodicActionHistory
__all__ = ["ActionHistoryComponent", "Episode", "EpisodicActionHistory"]


@ -1,27 +1,27 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional from typing import TYPE_CHECKING, Callable, Iterator, Optional
from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
from forge.llm.prompting.utils import indent from forge.llm.prompting.utils import indent
from forge.llm.providers import ChatMessage, ChatModelProvider from forge.llm.providers import ChatMessage, MultiProvider
if TYPE_CHECKING: if TYPE_CHECKING:
from forge.config.config import Config from forge.config.config import Config
from .model import AP, ActionResult, Episode, EpisodicActionHistory from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory
class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[AP]): class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
"""Keeps track of the event history and provides a summary of the steps.""" """Keeps track of the event history and provides a summary of the steps."""
def __init__( def __init__(
self, self,
event_history: EpisodicActionHistory[AP], event_history: EpisodicActionHistory[AnyProposal],
max_tokens: int, max_tokens: int,
count_tokens: Callable[[str], int], count_tokens: Callable[[str], int],
legacy_config: Config, legacy_config: Config,
llm_provider: ChatModelProvider, llm_provider: MultiProvider,
) -> None: ) -> None:
self.event_history = event_history self.event_history = event_history
self.max_tokens = max_tokens self.max_tokens = max_tokens
@ -37,7 +37,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
): ):
yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}") yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")
def after_parse(self, result: AP) -> None: def after_parse(self, result: AnyProposal) -> None:
self.event_history.register_action(result) self.event_history.register_action(result)
async def after_execute(self, result: ActionResult) -> None: async def after_execute(self, result: ActionResult) -> None:
@ -48,7 +48,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
def _compile_progress( def _compile_progress(
self, self,
episode_history: list[Episode], episode_history: list[Episode[AnyProposal]],
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
count_tokens: Optional[Callable[[str], int]] = None, count_tokens: Optional[Callable[[str], int]] = None,
) -> str: ) -> str:


@ -1,25 +1,23 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import TYPE_CHECKING, Generic, Iterator, TypeVar from typing import TYPE_CHECKING, Generic
from pydantic import Field from pydantic import Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
from forge.content_processing.text import summarize_text from forge.content_processing.text import summarize_text
from forge.llm.prompting.utils import format_numbered_list, indent from forge.llm.prompting.utils import format_numbered_list, indent
from forge.models.action import ActionProposal, ActionResult from forge.models.action import ActionResult, AnyProposal
from forge.models.utils import ModelWithSummary from forge.models.utils import ModelWithSummary
if TYPE_CHECKING: if TYPE_CHECKING:
from forge.config.config import Config from forge.config.config import Config
from forge.llm.providers import ChatModelProvider from forge.llm.providers import MultiProvider
AP = TypeVar("AP", bound=ActionProposal)
class Episode(GenericModel, Generic[AP]): class Episode(GenericModel, Generic[AnyProposal]):
action: AP action: AnyProposal
result: ActionResult | None result: ActionResult | None
summary: str | None = None summary: str | None = None
@ -54,32 +52,29 @@ class Episode(GenericModel, Generic[AP]):
return executed_action + action_result return executed_action + action_result
class EpisodicActionHistory(GenericModel, Generic[AP]): class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
"""Utility container for an action history""" """Utility container for an action history"""
episodes: list[Episode[AP]] = Field(default_factory=list) episodes: list[Episode[AnyProposal]] = Field(default_factory=list)
cursor: int = 0 cursor: int = 0
_lock = asyncio.Lock() _lock = asyncio.Lock()
@property @property
def current_episode(self) -> Episode[AP] | None: def current_episode(self) -> Episode[AnyProposal] | None:
if self.cursor == len(self): if self.cursor == len(self):
return None return None
return self[self.cursor] return self[self.cursor]
def __getitem__(self, key: int) -> Episode[AP]: def __getitem__(self, key: int) -> Episode[AnyProposal]:
return self.episodes[key] return self.episodes[key]
def __iter__(self) -> Iterator[Episode[AP]]:
return iter(self.episodes)
def __len__(self) -> int: def __len__(self) -> int:
return len(self.episodes) return len(self.episodes)
def __bool__(self) -> bool: def __bool__(self) -> bool:
return len(self.episodes) > 0 return len(self.episodes) > 0
def register_action(self, action: AP) -> None: def register_action(self, action: AnyProposal) -> None:
if not self.current_episode: if not self.current_episode:
self.episodes.append(Episode(action=action, result=None)) self.episodes.append(Episode(action=action, result=None))
assert self.current_episode assert self.current_episode
@ -113,7 +108,7 @@ class EpisodicActionHistory(GenericModel, Generic[AP]):
self.cursor = len(self.episodes) self.cursor = len(self.episodes)
async def handle_compression( async def handle_compression(
self, llm_provider: ChatModelProvider, app_config: Config self, llm_provider: MultiProvider, app_config: Config
) -> None: ) -> None:
"""Compresses each episode in the action history using an LLM. """Compresses each episode in the action history using an LLM.


@ -3,6 +3,11 @@ from .code_executor import (
DENYLIST_CONTROL, DENYLIST_CONTROL,
CodeExecutionError, CodeExecutionError,
CodeExecutorComponent, CodeExecutorComponent,
is_docker_available,
we_are_running_in_a_docker_container,
) )
__all__ = [
"ALLOWLIST_CONTROL",
"DENYLIST_CONTROL",
"CodeExecutionError",
"CodeExecutorComponent",
]

Some files were not shown because too many files have changed in this diff.