Set up unified pre-commit + CI w/ linting + type checking & FIX EVERYTHING (#7171)

- **FIX ALL LINT/TYPE ERRORS IN AUTOGPT, FORGE, AND BENCHMARK**

### Linting
- Clean up linter configs for `autogpt`, `forge`, and `benchmark`
- Add type checking with Pyright
- Create unified pre-commit config
- Create unified linting and type checking CI workflow

### Testing
- Synchronize CI test setups for `autogpt`, `forge`, and `benchmark`
   - Add missing pytest-cov to benchmark dependencies
- Mark GCS tests as slow to speed up pre-commit test runs
- Repair `forge` test suite
  - Add `AgentDB.close()` method for test DB teardown in `db_test.py` (see the sketch after this list)
  - Use actual temporary dir instead of forge/test_workspace/
- Move left-behind dependencies for `forge` code that was moved from `autogpt` to `forge`
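
For reference, a minimal sketch of what the new teardown hook amounts to, assuming a SQLAlchemy-backed `AgentDB` (the `engine`/`Session` attribute names are illustrative, not the exact forge internals):

```python
# Illustrative sketch only: assumes AgentDB wraps a SQLAlchemy engine/session.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


class AgentDB:
    def __init__(self, database_string: str) -> None:
        self.engine = create_engine(database_string)
        self.Session = sessionmaker(bind=self.engine)
        self.session = self.Session()

    def close(self) -> None:
        # Release the session and connection pool so the test database
        # (e.g. a temporary SQLite file) can be deleted during teardown.
        self.session.close()
        self.engine.dispose()
```

In `db_test.py`, each test can then call `close()` before removing its temporary database.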

### Notable type changes
- Replace uses of `ChatModelProvider` with `MultiProvider`
- Remove unnecessary exports from various `__init__.py` files
- Simplify `FileStorage.open_file` signature by removing `IOBase` from return type union
  - Implement `S3BinaryIOWrapper(BinaryIO)` type interposer for `S3FileStorage`
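
  A reduced sketch of the interposer idea, not the exact forge implementation (only the most relevant methods are shown; the real wrapper presumably covers more of the `BinaryIO` surface):

```python
# Simplified sketch, not the exact forge implementation: adapt a readable
# S3 response body (e.g. a botocore StreamingBody) to the BinaryIO interface.
from typing import Any, BinaryIO


class S3BinaryIOWrapper(BinaryIO):
    def __init__(self, body: Any) -> None:
        self._body = body  # assumed to expose .read() and .close()

    def read(self, size: int = -1) -> bytes:
        return self._body.read(None if size < 0 else size)

    def readable(self) -> bool:
        return True

    def writable(self) -> bool:
        return False

    def close(self) -> None:
        self._body.close()

    def __enter__(self) -> "S3BinaryIOWrapper":
        return self

    def __exit__(self, *exc) -> None:
        self.close()
```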

- Expand overloads of `GCSFileStorage.open_file` for improved typing of read and write modes

  Had to silence type checking for the extra overloads, because (I think) Pyright is reporting a false-positive:
  https://github.com/microsoft/pyright/issues/8007
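
  The pattern is roughly the following (placeholder names, not the exact `GCSFileStorage` signatures): overloads keyed on the `mode` literal let the type checker narrow the return type for text vs. binary access:

```python
# Placeholder sketch of the overload pattern; not the actual forge signatures.
from typing import BinaryIO, Literal, TextIO, overload


class ExampleStorage:
    @overload
    def open_file(self, path: str, mode: Literal["r", "w"] = "r") -> TextIO: ...

    @overload
    def open_file(self, path: str, mode: Literal["rb", "wb"]) -> BinaryIO: ...

    def open_file(self, path: str, mode: str = "r") -> TextIO | BinaryIO:
        # The runtime implementation stays generic; the overloads only
        # guide the type checker at call sites.
        return open(path, mode)  # type: ignore
```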

- Change `count_tokens`, `get_tokenizer`, `count_message_tokens` methods on `ModelProvider`s from class methods to instance methods
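
  Simplified illustration of the signature change (exact parameters may differ from the real provider API): token counting is now called on provider instances rather than on the classes, so it can use per-instance configuration.

```python
# Simplified illustration; not the full forge.llm.providers interface.
class ExampleModelProvider:
    def get_tokenizer(self, model_name: str):
        raise NotImplementedError

    def count_tokens(self, text: str, model_name: str) -> int:
        # Previously a @classmethod; as an instance method it can draw on
        # per-instance state (credentials, model settings, caches).
        return len(self.get_tokenizer(model_name).encode(text))

    def count_message_tokens(self, messages: list[str], model_name: str) -> int:
        return sum(self.count_tokens(message, model_name) for message in messages)
```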

- Move `CompletionModelFunction.schema` method -> helper function `format_function_def_for_openai` in `forge.llm.providers.openai`

- Rename `ModelProvider` -> `BaseModelProvider`
- Rename `ChatModelProvider` -> `BaseChatModelProvider`
- Add type `ChatModelProvider` which is a union of all subclasses of `BaseChatModelProvider`
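
  Conceptually (the concrete provider classes listed here are just examples, not an exhaustive list):

```python
# Conceptual sketch; the concrete subclasses in forge may differ.
from typing import TypeAlias


class BaseChatModelProvider:  # formerly `ChatModelProvider`
    ...


class OpenAIProvider(BaseChatModelProvider):
    ...


class AnthropicProvider(BaseChatModelProvider):
    ...


# The new `ChatModelProvider` names the union of concrete chat providers:
ChatModelProvider: TypeAlias = OpenAIProvider | AnthropicProvider
```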

### Removed rather than fixed
- Remove deprecated and broken autogpt/agbenchmark_config/benchmarks.py
- Remove various base classes and properties on base classes in `forge.llm.providers.schema` and `forge.models.providers`

### Fixes for other issues that came to light
- Clean up `forge.agent_protocol.api_router`, `forge.agent_protocol.database`, and `forge.agent.agent`

- Add fallback behavior to `ImageGeneratorComponent`
   - Remove test for deprecated failure behavior

- Fix `agbenchmark.challenges.builtin` challenge exclusion mechanism on Windows
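
  As the diff to `agbenchmark/challenges/builtin.py` below shows, the exclusion check now compares against the POSIX form of the path, so substring matching also works with Windows backslash separators:

```python
# Mirrors the fix in agbenchmark/challenges/builtin.py: normalize to POSIX
# separators before substring matching, so the check also works on Windows.
from pathlib import Path


def _challenge_should_be_ignored(json_file_path: Path) -> bool:
    return (
        "challenges/deprecated" in json_file_path.as_posix()
        or "challenges/library" in json_file_path.as_posix()
    )


# On Windows a Path renders with "\", but .as_posix() always yields forward
# slashes, so the substring check still matches.
print(_challenge_should_be_ignored(Path("challenges/deprecated/x/data.json")))  # True
```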

- Fix `_tool_calls_compat_extract_calls` in `forge.llm.providers.openai`

- Add support for `any` (= no type specified) in `JSONSchema.typescript_type`
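
  Roughly the idea (a trimmed-down sketch, not the actual forge `JSONSchema` implementation): a schema without a `type` now maps to TypeScript `any` instead of failing.

```python
# Trimmed-down sketch; the real JSONSchema model has more fields and types.
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class JSONSchema(BaseModel):
    class Type(str, Enum):
        STRING = "string"
        INTEGER = "integer"
        BOOLEAN = "boolean"

    type: Optional[Type] = None

    @property
    def typescript_type(self) -> str:
        if self.type is None:
            # No type specified -> any
            return "any"
        return {
            JSONSchema.Type.STRING: "string",
            JSONSchema.Type.INTEGER: "number",
            JSONSchema.Type.BOOLEAN: "boolean",
        }[self.type]
```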
Reinier van der Leer 2024-05-28 05:04:21 +02:00 committed by GitHub
parent 2c13a2706c
commit f107ff8cf0
147 changed files with 2897 additions and 2425 deletions


@ -1,4 +1,4 @@
name: AutoGPT Python CI
name: AutoGPT CI
on:
push:
@ -24,57 +24,6 @@ defaults:
working-directory: autogpt
jobs:
lint:
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles('autogpt/pyproject.toml') }}-${{ steps.get_date.outputs.date }}
- name: Install Python dependencies
run: |
curl -sSL https://install.python-poetry.org | python3 -
poetry install
- name: Lint with flake8
run: poetry run flake8
- name: Check black formatting
run: poetry run black . --check
if: success() || failure()
- name: Check isort formatting
run: poetry run isort . --check
if: success() || failure()
# - name: Check mypy formatting
# run: poetry run mypy
# if: success() || failure()
# - name: Check for unused imports and pass statements
# run: |
# cmd="autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests"
# poetry run $cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
test:
permissions:
contents: read


@ -1,4 +1,4 @@
name: AutoGPTs smoke test CI
name: Agent smoke tests
on:
workflow_dispatch:
@ -28,7 +28,7 @@ on:
- '!**/*.md'
jobs:
run-tests:
serve-agent-protocol:
runs-on: ubuntu-latest
strategy:
matrix:


@ -1,4 +1,4 @@
name: Benchmark CI
name: AGBenchmark CI
on:
push:
@ -14,62 +14,91 @@ on:
- '!benchmark/reports/**'
- .github/workflows/benchmark-ci.yml
concurrency:
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
defaults:
run:
shell: bash
env:
min-python-version: '3.10'
jobs:
lint:
runs-on: ubuntu-latest
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
defaults:
run:
shell: bash
working-directory: benchmark
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ env.min-python-version }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
python-version: ${{ matrix.python-version }}
- id: get_date
name: Get date
working-directory: ./benchmark/
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('benchmark/poetry.lock') }}
- name: Install Poetry
working-directory: ./benchmark/
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python -
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
working-directory: ./benchmark/
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
export POETRY_VIRTUALENVS_IN_PROJECT=true
poetry install -vvv
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
- name: Lint with flake8
working-directory: ./benchmark/
run: poetry run flake8
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Check black formatting
working-directory: ./benchmark/
run: poetry run black . --exclude test.py --check
if: success() || failure()
- name: Install Python dependencies
run: poetry install
- name: Check isort formatting
working-directory: ./benchmark/
run: poetry run isort . --check
if: success() || failure()
- name: Check for unused imports and pass statements
working-directory: ./benchmark/
- name: Run pytest with coverage
run: |
cmd="poetry run autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring agbenchmark"
$cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
if: success() || failure()
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
tests
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
tests-agbenchmark:
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}
self-test-with-agent:
runs-on: ubuntu-latest
strategy:
matrix:
@ -89,11 +118,11 @@ jobs:
python-version: ${{ env.min-python-version }}
- name: Install Poetry
working-directory: ./${{ matrix.agent-name }}/
run: |
curl -sSL https://install.python-poetry.org | python -
- name: Run regression tests
working-directory: .
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
@ -125,7 +154,6 @@ jobs:
export BUILD_SKILL_TREE=true
poetry run agbenchmark --mock
poetry run pytest -vv -s tests
CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../frontend/assets)') || echo "No diffs"
if [ ! -z "$CHANGED" ]; then

.github/workflows/forge-ci.yml (new file, 129 lines)

@ -0,0 +1,129 @@
name: Forge CI
on:
push:
branches: [ master, development, ci-test* ]
paths:
- '.github/workflows/forge-ci.yml'
- 'forge/**'
pull_request:
branches: [ master, development, release-* ]
paths:
- '.github/workflows/forge-ci.yml'
- 'forge/**'
concurrency:
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
defaults:
run:
shell: bash
working-directory: forge
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('forge/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
forge
env:
CI: true
PLAIN_OUTPUT: True
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: forge/logs/

.github/workflows/python-checks.yml (new file, 151 lines)

@ -0,0 +1,151 @@
name: Python checks
on:
push:
branches: [ master, development, ci-test* ]
paths:
- '.github/workflows/lint-ci.yml'
- 'autogpt/**'
- 'forge/**'
- 'benchmark/**'
- '**.py'
- '!autogpt/tests/vcr_cassettes'
pull_request:
branches: [ master, development, release-* ]
paths:
- '.github/workflows/lint-ci.yml'
- 'autogpt/**'
- 'forge/**'
- 'benchmark/**'
- '**.py'
- '!autogpt/tests/vcr_cassettes'
concurrency:
group: ${{ format('lint-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
defaults:
run:
shell: bash
jobs:
get-changed-parts:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- id: changes-in
name: Determine affected subprojects
uses: dorny/paths-filter@v3
with:
filters: |
autogpt:
- autogpt/autogpt/**
- autogpt/tests/**
- autogpt/poetry.lock
forge:
- forge/forge/**
- forge/tests/**
- forge/poetry.lock
benchmark:
- benchmark/agbenchmark/**
- benchmark/tests/**
- benchmark/poetry.lock
outputs:
changed-parts: ${{ steps.changes-in.outputs.changes }}
lint:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C ${{ matrix.sub-package }} install
# Lint
- name: Lint (isort)
run: poetry run isort --check .
working-directory: ${{ matrix.sub-package }}
- name: Lint (Black)
if: success() || failure()
run: poetry run black --check .
working-directory: ${{ matrix.sub-package }}
- name: Lint (Flake8)
if: success() || failure()
run: poetry run flake8 .
working-directory: ${{ matrix.sub-package }}
types:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C ${{ matrix.sub-package }} install
# Typecheck
- name: Typecheck
if: success() || failure()
run: poetry run pyright
working-directory: ${{ matrix.sub-package }}

.pre-commit-config.yaml (new file, 127 lines)

@ -0,0 +1,127 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ["--maxkb=500"]
- id: fix-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: local
# isort needs the context of which packages are installed to function, so we
# can't use a vendored isort pre-commit hook (which runs in its own isolated venv).
hooks:
- id: isort-autogpt
name: Lint (isort) - AutoGPT
entry: poetry -C autogpt run isort
files: ^autogpt/
types: [file, python]
language: system
- id: isort-forge
name: Lint (isort) - Forge
entry: poetry -C forge run isort
files: ^forge/
types: [file, python]
language: system
- id: isort-benchmark
name: Lint (isort) - Benchmark
entry: poetry -C benchmark run isort
files: ^benchmark/
types: [file, python]
language: system
- repo: https://github.com/psf/black
rev: 23.12.1
# Black has sensible defaults, doesn't need package context, and ignores
# everything in .gitignore, so it works fine without any config or arguments.
hooks:
- id: black
name: Lint (Black)
language_version: python3.10
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
# To have flake8 load the config of the individual subprojects, we have to call
# them separately.
hooks:
- id: flake8
name: Lint (Flake8) - AutoGPT
alias: flake8-autogpt
files: ^autogpt/(autogpt|scripts|tests)/
args: [--config=autogpt/.flake8]
- id: flake8
name: Lint (Flake8) - Forge
alias: flake8-forge
files: ^forge/(forge|tests)/
args: [--config=forge/.flake8]
- id: flake8
name: Lint (Flake8) - Benchmark
alias: flake8-benchmark
files: ^benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=benchmark/.flake8]
- repo: local
# To have watertight type checking, we check *all* the files in an affected
# project. To trigger on poetry.lock we also reset the file `types` filter.
hooks:
- id: pyright
name: Typecheck - AutoGPT
alias: pyright-autogpt
entry: poetry -C autogpt run pyright
args: [-p, autogpt, autogpt]
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Forge
alias: pyright-forge
entry: poetry -C forge run pyright
args: [-p, forge, forge]
files: ^forge/(forge/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Benchmark
alias: pyright-benchmark
entry: poetry -C benchmark run pyright
args: [-p, benchmark, benchmark]
files: ^benchmark/(agbenchmark|tests)/
types: [file]
language: system
pass_filenames: false
- repo: local
hooks:
- id: pytest-autogpt
name: Run tests - AutoGPT (excl. slow tests)
entry: bash -c 'cd autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(autogpt/((autogpt|tests)/|poetry\.lock$)|forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
language: system
pass_filenames: false
- id: pytest-forge
name: Run tests - Forge (excl. slow tests)
entry: bash -c 'cd forge && poetry run pytest --cov=forge -m "not slow"'
files: ^forge/(forge/|tests/|poetry\.lock$)
language: system
pass_filenames: false
- id: pytest-benchmark
name: Run tests - Benchmark
entry: bash -c 'cd benchmark && poetry run pytest --cov=benchmark'
files: ^benchmark/(agbenchmark/|tests/|poetry\.lock$)
language: system
pass_filenames: false


@ -1,11 +1,14 @@
[flake8]
max-line-length = 88
extend-exclude =
.*_cache/,
.venv,
# Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude =
.git,
__pycache__/,
*.pyc,
.pytest_cache/,
venv*/,
.venv/,
data/,
logs/,
tests/unit/data/,
extend-ignore =
# No whitespace before ':' conflicts with Black style for slices
E203,


@ -1,47 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ['--maxkb=500']
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.10
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.10
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: 'v1.3.0'
# hooks:
# - id: mypy
- repo: local
hooks:
# - id: autoflake
# name: autoflake
# entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests
# language: python
# types: [ python ]
- id: pytest-check
name: pytest-check
entry: bash -c 'cd autogpt && poetry run pytest --cov=autogpt tests/unit'
language: system
pass_filenames: false
always_run: true


@ -1,77 +0,0 @@
import asyncio
import logging
import sys
from pathlib import Path
from forge.config.ai_profile import AIProfile
from forge.config.config import ConfigBuilder
from forge.file_storage import FileStorageBackendName, get_storage
from forge.logging.config import configure_logging
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.agents.agent_manager import AgentManager
from autogpt.app.main import _configure_llm_provider, run_interaction_loop
LOG_DIR = Path(__file__).parent / "logs"
def run_specific_agent(task: str, continuous_mode: bool = False) -> None:
agent = bootstrap_agent(task, continuous_mode)
asyncio.run(run_interaction_loop(agent))
def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
configure_logging(
level=logging.DEBUG,
log_dir=LOG_DIR,
plain_console_output=True,
)
config = ConfigBuilder.build_config_from_env()
config.continuous_mode = continuous_mode
config.continuous_limit = 20
config.noninteractive_mode = True
ai_profile = AIProfile(
ai_name="AutoGPT",
ai_role="a multi-purpose AI assistant.",
ai_goals=[task],
)
agent_settings = AgentSettings(
name=Agent.default_settings.name,
agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
description=Agent.default_settings.description,
ai_profile=ai_profile,
config=AgentConfiguration(
fast_llm=config.fast_llm,
smart_llm=config.smart_llm,
allow_fs_access=not config.restrict_to_workspace,
use_functions_api=config.openai_functions,
),
history=Agent.default_settings.history.copy(deep=True),
)
local = config.file_storage_backend == FileStorageBackendName.LOCAL
restrict_to_root = not local or config.restrict_to_workspace
file_storage = get_storage(
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
)
file_storage.initialize()
agent = Agent(
settings=agent_settings,
llm_provider=_configure_llm_provider(config),
file_storage=file_storage,
legacy_config=config,
)
return agent
if __name__ == "__main__":
# The first argument is the script name itself, second is the task
if len(sys.argv) != 2:
print("Usage: python script.py <task>")
sys.exit(1)
task = sys.argv[1]
run_specific_agent(task, continuous_mode=True)


@ -4,7 +4,7 @@ from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.config.config import Config
from forge.file_storage.base import FileStorage
from forge.llm.providers import ChatModelProvider
from forge.llm.providers import MultiProvider
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@ -15,7 +15,7 @@ def create_agent(
ai_profile: AIProfile,
app_config: Config,
file_storage: FileStorage,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
directives: Optional[AIDirectives] = None,
) -> Agent:
if not task:
@ -39,7 +39,7 @@ def configure_agent_with_state(
state: AgentSettings,
app_config: Config,
file_storage: FileStorage,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
) -> Agent:
return _configure_agent(
state=state,
@ -51,7 +51,7 @@ def configure_agent_with_state(
def _configure_agent(
app_config: Config,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
file_storage: FileStorage,
agent_id: str = "",
task: str = "",
@ -59,20 +59,22 @@ def _configure_agent(
directives: Optional[AIDirectives] = None,
state: Optional[AgentSettings] = None,
) -> Agent:
if not (state or agent_id and task and ai_profile and directives):
if state:
agent_state = state
elif agent_id and task and ai_profile and directives:
agent_state = state or create_agent_state(
agent_id=agent_id,
task=task,
ai_profile=ai_profile,
directives=directives,
app_config=app_config,
)
else:
raise TypeError(
"Either (state) or (agent_id, task, ai_profile, directives)"
" must be specified"
)
agent_state = state or create_agent_state(
agent_id=agent_id,
task=task,
ai_profile=ai_profile,
directives=directives,
app_config=app_config,
)
return Agent(
settings=agent_state,
llm_provider=llm_provider,


@ -7,7 +7,7 @@ from forge.file_storage.base import FileStorage
if TYPE_CHECKING:
from autogpt.agents.agent import Agent
from forge.config.config import Config
from forge.llm.providers.schema import ChatModelProvider
from forge.llm.providers import MultiProvider
from .configurators import _configure_agent
from .profile_generator import generate_agent_profile_for_task
@ -18,7 +18,7 @@ async def generate_agent_for_task(
task: str,
app_config: Config,
file_storage: FileStorage,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
) -> Agent:
ai_profile, task_directives = await generate_agent_profile_for_task(
task=task,


@ -5,10 +5,10 @@ from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.config.config import Config
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
from forge.llm.providers import MultiProvider
from forge.llm.providers.schema import (
AssistantChatMessage,
ChatMessage,
ChatModelProvider,
CompletionModelFunction,
)
from forge.models.config import SystemConfiguration, UserConfigurable
@ -141,7 +141,7 @@ class AgentProfileGeneratorConfiguration(SystemConfiguration):
required=True,
),
},
).schema
).dict()
)
@ -160,7 +160,7 @@ class AgentProfileGenerator(PromptStrategy):
self._model_classification = model_classification
self._system_prompt_message = system_prompt
self._user_prompt_template = user_prompt_template
self._create_agent_function = CompletionModelFunction.parse(
self._create_agent_function = CompletionModelFunction.parse_obj(
create_agent_function
)
@ -183,7 +183,7 @@ class AgentProfileGenerator(PromptStrategy):
def parse_response_content(
self,
response_content: AssistantChatMessage,
response: AssistantChatMessage,
) -> tuple[AIProfile, AIDirectives]:
"""Parse the actual text response from the objective model.
@ -195,15 +195,15 @@ class AgentProfileGenerator(PromptStrategy):
"""
try:
if not response_content.tool_calls:
if not response.tool_calls:
raise ValueError(
f"LLM did not call {self._create_agent_function.name} function; "
"agent profile creation failed"
)
arguments: object = response_content.tool_calls[0].function.arguments
arguments: object = response.tool_calls[0].function.arguments
ai_profile = AIProfile(
ai_name=arguments.get("name"),
ai_role=arguments.get("description"),
ai_name=arguments.get("name"), # type: ignore
ai_role=arguments.get("description"), # type: ignore
)
ai_directives = AIDirectives(
best_practices=arguments.get("directives", {}).get("best_practices"),
@ -211,7 +211,7 @@ class AgentProfileGenerator(PromptStrategy):
resources=[],
)
except KeyError:
logger.debug(f"Failed to parse this response content: {response_content}")
logger.debug(f"Failed to parse this response content: {response}")
raise
return ai_profile, ai_directives
@ -219,7 +219,7 @@ class AgentProfileGenerator(PromptStrategy):
async def generate_agent_profile_for_task(
task: str,
app_config: Config,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
) -> tuple[AIProfile, AIDirectives]:
"""Generates an AIConfig object from the given string.


@ -24,7 +24,7 @@ class MyAgent(Agent):
def __init__(
self,
settings: AgentSettings,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
file_storage: FileStorage,
legacy_config: Config,
):


@ -3,7 +3,7 @@ from __future__ import annotations
import inspect
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Optional
import sentry_sdk
from forge.agent.base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
@ -14,7 +14,7 @@ from forge.agent.protocols import (
DirectiveProvider,
MessageProvider,
)
from forge.command.command import Command, CommandOutput
from forge.command.command import Command
from forge.components.action_history import (
ActionHistoryComponent,
EpisodicActionHistory,
@ -34,8 +34,8 @@ from forge.llm.prompting.utils import dump_prompt
from forge.llm.providers import (
AssistantFunctionCall,
ChatMessage,
ChatModelProvider,
ChatModelResponse,
MultiProvider,
)
from forge.llm.providers.utils import function_specs_from_commands
from forge.models.action import (
@ -76,7 +76,9 @@ class AgentConfiguration(BaseAgentConfiguration):
class AgentSettings(BaseAgentSettings):
config: AgentConfiguration = Field(default_factory=AgentConfiguration)
config: AgentConfiguration = Field( # type: ignore
default_factory=AgentConfiguration
)
history: EpisodicActionHistory[OneShotAgentActionProposal] = Field(
default_factory=EpisodicActionHistory[OneShotAgentActionProposal]
@ -86,8 +88,8 @@ class AgentSettings(BaseAgentSettings):
context: AgentContext = Field(default_factory=AgentContext)
class Agent(BaseAgent, Configurable[AgentSettings]):
default_settings: AgentSettings = AgentSettings(
class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
default_settings: ClassVar[AgentSettings] = AgentSettings(
name="Agent",
description=__doc__ if __doc__ else "",
)
@ -95,7 +97,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
def __init__(
self,
settings: AgentSettings,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
file_storage: FileStorage,
legacy_config: Config,
):
@ -280,7 +282,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
return result
async def _execute_tool(self, tool_call: AssistantFunctionCall) -> CommandOutput:
async def _execute_tool(self, tool_call: AssistantFunctionCall) -> Any:
"""Execute the command and return the result
Args:


@ -43,7 +43,7 @@ class AssistantThoughts(ModelWithSummary):
class OneShotAgentActionProposal(ActionProposal):
thoughts: AssistantThoughts
thoughts: AssistantThoughts # type: ignore
class OneShotAgentPromptConfiguration(SystemConfiguration):
@ -186,11 +186,8 @@ class OneShotAgentPromptStrategy(PromptStrategy):
def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
response_schema = self.response_schema.copy(deep=True)
if (
use_functions_api
and response_schema.properties
and "use_tool" in response_schema.properties
):
assert response_schema.properties
if use_functions_api and "use_tool" in response_schema.properties:
del response_schema.properties["use_tool"]
# Unindent for performance
@ -288,10 +285,10 @@ class OneShotAgentPromptStrategy(PromptStrategy):
"Parsing object extracted from LLM response:\n"
f"{json.dumps(assistant_reply_dict, indent=4)}"
)
parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
if self.config.use_functions_api:
if not response.tool_calls:
raise InvalidAgentResponseError("Assistant did not use a tool")
parsed_response.use_tool = response.tool_calls[0].function
assistant_reply_dict["use_tool"] = response.tool_calls[0].function
parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
return parsed_response


@ -25,7 +25,7 @@ from forge.agent_protocol.models import (
)
from forge.config.config import Config
from forge.file_storage import FileStorage
from forge.llm.providers import ChatModelProvider, ModelProviderBudget
from forge.llm.providers import ModelProviderBudget, MultiProvider
from forge.models.action import ActionErrorResult, ActionSuccessResult
from forge.utils.const import ASK_COMMAND, FINISH_COMMAND
from forge.utils.exceptions import AgentFinished, NotFoundError
@ -49,7 +49,7 @@ class AgentProtocolServer:
app_config: Config,
database: AgentDB,
file_storage: FileStorage,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
):
self.app_config = app_config
self.db = database
@ -444,9 +444,7 @@ class AgentProtocolServer:
agent_id = task_agent_id(task_id)
return self.file_storage.clone_with_subroot(f"agents/{agent_id}/workspace")
def _get_task_llm_provider(
self, task: Task, step_id: str = ""
) -> ChatModelProvider:
def _get_task_llm_provider(self, task: Task, step_id: str = "") -> MultiProvider:
"""
Configures the LLM provider with headers to link outgoing requests to the task.
"""


@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional
from colorama import Fore, Style
from forge.agent_protocol.database import AgentDB
from forge.components.code_executor import (
from forge.components.code_executor.code_executor import (
is_docker_available,
we_are_running_in_a_docker_container,
)
@ -82,7 +82,9 @@ async def run_auto_gpt(
local = config.file_storage_backend == FileStorageBackendName.LOCAL
restrict_to_root = not local or config.restrict_to_workspace
file_storage = get_storage(
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
config.file_storage_backend,
root_path=Path("data"),
restrict_to_root=restrict_to_root,
)
file_storage.initialize()
@ -353,7 +355,9 @@ async def run_auto_gpt_server(
local = config.file_storage_backend == FileStorageBackendName.LOCAL
restrict_to_root = not local or config.restrict_to_workspace
file_storage = get_storage(
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
config.file_storage_backend,
root_path=Path("data"),
restrict_to_root=restrict_to_root,
)
file_storage.initialize()


@ -7,7 +7,7 @@ import re
import socket
import sys
from pathlib import Path
from typing import Any, Callable, Coroutine, ParamSpec, TypeVar
from typing import Any, Callable, Coroutine, ParamSpec, TypeVar, cast
import requests
from colorama import Fore, Style
@ -88,7 +88,7 @@ def vcs_state_diverges_from_master() -> bool:
def get_git_user_email() -> str:
try:
repo = Repo(search_parent_directories=True)
return repo.config_reader().get_value("user", "email", default="")
return cast(str, repo.config_reader().get_value("user", "email", default=""))
except InvalidGitRepositoryError:
return ""

autogpt/poetry.lock (generated, 529 lines changed): diff suppressed because one or more lines are too long


@ -1,9 +1,7 @@
[tool.poetry]
name = "agpt"
version = "0.5.0"
authors = [
"Significant Gravitas <support@agpt.co>",
]
authors = ["Significant Gravitas <support@agpt.co>"]
readme = "README.md"
description = "An open-source attempt to make GPT-4 autonomous"
homepage = "https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt"
@ -30,11 +28,10 @@ charset-normalizer = "^3.1.0"
click = "*"
colorama = "^0.4.6"
distro = "^1.8.0"
en-core-web-sm = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl"}
en-core-web-sm = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl" }
fastapi = "^0.109.1"
ftfy = "^6.1.1"
google-api-python-client = "*"
gTTS = "^2.3.1"
hypercorn = "^0.14.4"
inflection = "*"
jsonschema = "*"
@ -58,21 +55,18 @@ openapi-python-client = "^0.14.0"
# Benchmarking
agbenchmark = { path = "../benchmark", optional = true }
# agbenchmark = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "benchmark", optional = true}
google-cloud-logging = "^3.8.0"
google-cloud-storage = "^2.13.0"
psycopg2-binary = "^2.9.9"
[tool.poetry.extras]
benchmark = ["agbenchmark"]
[tool.poetry.group.dev.dependencies]
black = "*"
boto3-stubs = {extras = ["s3"], version = "^1.33.6"}
flake8 = "*"
black = "^23.12.1"
flake8 = "^7.0.0"
gitpython = "^3.1.32"
isort = "*"
mypy = "*"
isort = "^5.13.1"
pre-commit = "*"
pyright = "^1.1.364"
types-beautifulsoup4 = "*"
types-colorama = "*"
types-Markdown = "*"
@ -89,7 +83,7 @@ pytest-integration = "*"
pytest-mock = "*"
pytest-recording = "*"
pytest-xdist = "*"
vcrpy = {git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master"}
vcrpy = { git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master" }
[build-system]
@ -101,50 +95,18 @@ build-backend = "poetry.core.masonry.api"
line-length = 88
target-version = ['py310']
include = '\.pyi?$'
packages = ["autogpt"]
extend-exclude = '.+/(dist|.venv|venv|build|data)/.+'
[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
sections = [
"FUTURE",
"STDLIB",
"THIRDPARTY",
"FIRSTPARTY",
"LOCALFOLDER"
]
extend_skip = [
"agbenchmark_config/temp_folder/",
"data/",
]
skip_glob = ["data"]
[tool.mypy]
follow_imports = 'skip'
check_untyped_defs = true
disallow_untyped_calls = true
files = [
'autogpt/**/*.py',
'tests/**/*.py'
]
[[tool.mypy.overrides]]
module = [
'requests.*',
'yaml.*'
]
ignore_missing_imports = true
[tool.pyright]
pythonVersion = "3.10"
exclude = ["data/**", "**/node_modules", "**/__pycache__", "**/.*"]
ignore = ["../forge/**"]
[tool.pytest.ini_options]
markers = [
"requires_openai_api_key",
"requires_huggingface_api_key"
]
markers = ["slow", "requires_openai_api_key", "requires_huggingface_api_key"]


@ -4,12 +4,12 @@ import sys
from importlib.metadata import version
try:
import poetry.factory # noqa
import poetry.factory # type: ignore # noqa
except ModuleNotFoundError:
os.system(f"{sys.executable} -m pip install 'poetry>=1.6.1,<2.0.0'")
from poetry.core.constraints.version.version import Version
from poetry.factory import Factory
from poetry.core.constraints.version.version import Version # type: ignore
from poetry.factory import Factory # type: ignore
def main():


@ -20,7 +20,7 @@ from autogpt.app.utils import coroutine
)
@coroutine
async def generate_release_notes(repo_path: Optional[Path] = None):
logger = logging.getLogger(generate_release_notes.name)
logger = logging.getLogger(generate_release_notes.name) # pyright: ignore
repo = Repo(repo_path, search_parent_directories=True)
tags = list(repo.tags)


@ -12,7 +12,7 @@ from forge.file_storage.local import (
FileStorageConfiguration,
LocalFileStorage,
)
from forge.llm.providers import ChatModelProvider
from forge.llm.providers import MultiProvider
from forge.logging.config import configure_logging
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@ -71,14 +71,12 @@ def setup_logger(config: Config):
@pytest.fixture
def llm_provider(config: Config) -> ChatModelProvider:
def llm_provider(config: Config) -> MultiProvider:
return _configure_llm_provider(config)
@pytest.fixture
def agent(
config: Config, llm_provider: ChatModelProvider, storage: FileStorage
) -> Agent:
def agent(config: Config, llm_provider: MultiProvider, storage: FileStorage) -> Agent:
ai_profile = AIProfile(
ai_name="Base",
ai_role="A base AI",


@ -1,13 +1,16 @@
from pathlib import Path
import pytest
from forge.config.ai_profile import AIProfile
from forge.config.config import Config
from forge.file_storage import FileStorageBackendName, get_storage
from forge.llm.providers import MultiProvider
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@pytest.fixture
def dummy_agent(config: Config, llm_provider, memory_json_file):
def dummy_agent(config: Config, llm_provider: MultiProvider):
ai_profile = AIProfile(
ai_name="Dummy Agent",
ai_role="Dummy Role",
@ -31,7 +34,9 @@ def dummy_agent(config: Config, llm_provider, memory_json_file):
local = config.file_storage_backend == FileStorageBackendName.LOCAL
restrict_to_root = not local or config.restrict_to_workspace
file_storage = get_storage(
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
config.file_storage_backend,
root_path=Path("data"),
restrict_to_root=restrict_to_root,
)
file_storage.initialize()


@ -4,7 +4,7 @@ import tempfile
from pathlib import Path
import pytest
from forge.components.code_executor import (
from forge.components.code_executor.code_executor import (
ALLOWLIST_CONTROL,
CodeExecutorComponent,
is_docker_available,


@ -257,17 +257,3 @@ def test_huggingface_fail_request_bad_image(
result = image_gen_component.generate_image("astronaut riding a horse", 512)
assert result == "Error creating image."
def test_huggingface_fail_missing_api_token(
mocker, image_gen_component: ImageGeneratorComponent, agent: Agent
):
agent.legacy_config.image_provider = "huggingface"
agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
# Mock requests.post to raise ValueError
mocker.patch("requests.post", side_effect=ValueError)
# Verify request raises an error.
with pytest.raises(ValueError):
image_gen_component.generate_image("astronaut riding a horse", 512)


@ -67,8 +67,8 @@ def test_missing_azure_config(config: Config) -> None:
with pytest.raises(ValueError):
config.openai_credentials.load_azure_config(config_file)
assert config.openai_credentials.api_type != "azure"
assert config.openai_credentials.api_version == ""
assert config.openai_credentials.api_type != SecretStr("azure")
assert config.openai_credentials.api_version is None
assert config.openai_credentials.azure_model_to_deploy_id_map is None
@ -98,8 +98,8 @@ azure_model_map:
def test_azure_config(config_with_azure: Config) -> None:
assert (credentials := config_with_azure.openai_credentials) is not None
assert credentials.api_type == "azure"
assert credentials.api_version == "2023-06-01-preview"
assert credentials.api_type == SecretStr("azure")
assert credentials.api_version == SecretStr("2023-06-01-preview")
assert credentials.azure_endpoint == SecretStr("https://dummy.openai.azure.com")
assert credentials.azure_model_to_deploy_id_map == {
config_with_azure.fast_llm: "FAST-LLM_ID",


@ -4,7 +4,7 @@ from pathlib import Path
import pytest
import pytest_asyncio
from forge.file_storage import GCSFileStorage, GCSFileStorageConfiguration
from forge.file_storage.gcs import GCSFileStorage, GCSFileStorageConfiguration
from google.auth.exceptions import GoogleAuthError
from google.cloud import storage
from google.cloud.exceptions import NotFound
@ -14,6 +14,8 @@ try:
except GoogleAuthError:
pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module")
def gcs_bucket_name() -> str:
@ -26,7 +28,7 @@ def gcs_root() -> Path:
@pytest.fixture(scope="module")
def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path) -> GCSFileStorage:
def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path):
os.environ["STORAGE_BUCKET"] = gcs_bucket_name
storage_config = GCSFileStorageConfiguration.from_env()
storage_config.root = gcs_root
@ -52,7 +54,7 @@ def test_initialize(gcs_bucket_name: str, gcs_storage_uninitialized: GCSFileStor
@pytest.fixture(scope="module")
def gcs_storage(gcs_storage_uninitialized: GCSFileStorage) -> GCSFileStorage:
def gcs_storage(gcs_storage_uninitialized: GCSFileStorage):
(gcs_storage := gcs_storage_uninitialized).initialize()
yield gcs_storage # type: ignore
@ -77,7 +79,7 @@ TEST_FILES: list[tuple[str | Path, str]] = [
@pytest_asyncio.fixture
async def gcs_storage_with_files(gcs_storage: GCSFileStorage) -> GCSFileStorage:
async def gcs_storage_with_files(gcs_storage: GCSFileStorage):
for file_name, file_content in TEST_FILES:
gcs_storage._bucket.blob(
str(gcs_storage.get_path(file_name))


@ -1,7 +1,7 @@
import json
import pytest
from forge.json import json_loads
from forge.json.parsing import json_loads
_JSON_FIXABLE: list[tuple[str, str]] = [
# Missing comma


@ -1,7 +1,7 @@
from pathlib import Path
import pytest
from forge.file_storage import FileStorageConfiguration, LocalFileStorage
from forge.file_storage.local import FileStorageConfiguration, LocalFileStorage
_ACCESSIBLE_PATHS = [
Path("."),


@ -5,7 +5,7 @@ from pathlib import Path
import pytest
import pytest_asyncio
from botocore.exceptions import ClientError
from forge.file_storage import S3FileStorage, S3FileStorageConfiguration
from forge.file_storage.s3 import S3FileStorage, S3FileStorageConfiguration
if not (os.getenv("S3_ENDPOINT_URL") and os.getenv("AWS_ACCESS_KEY_ID")):
pytest.skip("S3 environment variables are not set", allow_module_level=True)
@ -22,7 +22,7 @@ def s3_root() -> Path:
@pytest.fixture
def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path) -> S3FileStorage:
def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path):
os.environ["STORAGE_BUCKET"] = s3_bucket_name
storage_config = S3FileStorageConfiguration.from_env()
storage_config.root = s3_root
@ -36,12 +36,13 @@ def test_initialize(s3_bucket_name: str, s3_storage_uninitialized: S3FileStorage
# test that the bucket doesn't exist yet
with pytest.raises(ClientError):
s3.meta.client.head_bucket(Bucket=s3_bucket_name)
s3.meta.client.head_bucket(Bucket=s3_bucket_name) # pyright: ignore
s3_storage_uninitialized.initialize()
# test that the bucket has been created
s3.meta.client.head_bucket(Bucket=s3_bucket_name)
s3.meta.client.head_bucket(Bucket=s3_bucket_name) # pyright: ignore
# FIXME: remove the "pyright: ignore" comments after moving this test file to forge
def test_workspace_bucket_name(
@ -52,7 +53,7 @@ def test_workspace_bucket_name(
@pytest.fixture
def s3_storage(s3_storage_uninitialized: S3FileStorage) -> S3FileStorage:
def s3_storage(s3_storage_uninitialized: S3FileStorage):
(s3_storage := s3_storage_uninitialized).initialize()
yield s3_storage # type: ignore
@ -71,7 +72,7 @@ TEST_FILES: list[tuple[str | Path, str]] = [
@pytest_asyncio.fixture
async def s3_storage_with_files(s3_storage: S3FileStorage) -> S3FileStorage:
async def s3_storage_with_files(s3_storage: S3FileStorage):
for file_name, file_content in TEST_FILES:
s3_storage._bucket.Object(str(s3_storage.get_path(file_name))).put(
Body=file_content


@ -1,6 +1,7 @@
import logging
import os
from hashlib import sha256
from typing import cast
import pytest
from openai import OpenAI
@ -53,11 +54,14 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
def _patched_prepare_options(self, options: FinalRequestOptions):
_prepare_options(options)
if not options.json_data:
return
headers: dict[str, str | Omit] = (
{**options.headers} if is_given(options.headers) else {}
)
options.headers = headers
data: dict = options.json_data
data = cast(dict, options.json_data)
logging.getLogger("cached_openai_client").debug(
f"Outgoing API request: {headers}\n{data if data else None}"


@ -1,15 +1,12 @@
[flake8]
max-line-length = 88
select = "E303, W293, W291, W292, E305, E231, E302"
# Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude =
.tox,
__pycache__,
__pycache__/,
*.pyc,
.env
venv*/*,
.venv/*,
reports/*,
dist/*,
agent/*,
code,
agbenchmark/challenges/*
.pytest_cache/,
venv*/,
.venv/,
reports/,
agbenchmark/reports/,


@ -1,36 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ['--maxkb=500']
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.10
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.10
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.3.0'
hooks:
- id: mypy
- repo: local
hooks:
- id: autoflake
name: autoflake
entry: autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring --in-place agbenchmark
language: python
types: [ python ]


@ -28,7 +28,7 @@ async def run_api_agent(
configuration = Configuration(host=config.host)
async with ApiClient(configuration) as api_client:
api_instance = AgentApi(api_client)
task_request_body = TaskRequestBody(input=task)
task_request_body = TaskRequestBody(input=task, additional_input=None)
start_time = time.time()
response = await api_instance.create_agent_task(


@ -106,8 +106,8 @@ def find_agbenchmark_without_uvicorn():
class CreateReportRequest(BaseModel):
test: str = None
test_run_id: str = None
test: str
test_run_id: str
# category: Optional[str] = []
mock: Optional[bool] = False
@ -178,8 +178,8 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
logger.debug(f"Benchmark finished running in {time.time() - start_time} s")
# List all folders in the current working directory
path_reports = agbenchmark_config.reports_folder
folders = [folder for folder in path_reports.iterdir() if folder.is_dir()]
reports_folder = agbenchmark_config.reports_folder
folders = [folder for folder in reports_folder.iterdir() if folder.is_dir()]
# Sort the folders based on their names
sorted_folders = sorted(folders, key=lambda x: x.name)
@ -196,13 +196,14 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
data = json.load(file)
logger.debug(f"Report data: {data}")
else:
logger.error(
raise HTTPException(
502,
"Could not get result after running benchmark: "
f"'report.json' does not exist in '{latest_folder}'"
f"'report.json' does not exist in '{latest_folder}'",
)
else:
logger.error(
"Could not get result after running benchmark: no reports found"
raise HTTPException(
504, "Could not get result after running benchmark: no reports found"
)
return data
@ -239,7 +240,9 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
api_instance = AgentApi(api_client)
task_input = challenge_info.task
task_request_body = TaskRequestBody(input=task_input)
task_request_body = TaskRequestBody(
input=task_input, additional_input=None
)
task_response = await api_instance.create_agent_task(
task_request_body=task_request_body
)
@ -276,7 +279,7 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
# Forward the request
response = await client.post(
new_url,
data=await request.body(),
content=await request.body(),
headers=dict(request.headers),
)


@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import AsyncIterator, ClassVar, Optional
from typing import AsyncIterator, Awaitable, ClassVar, Optional
import pytest
from agent_protocol_client import AgentApi, Step
@ -54,7 +54,7 @@ class BaseChallenge(ABC):
config: AgentBenchmarkConfig,
request: pytest.FixtureRequest,
i_attempt: int,
) -> None:
) -> None | Awaitable[None]:
"""
Test method for use by Pytest-based benchmark sessions. Should return normally
if the challenge passes, and raise a (preferably descriptive) error otherwise.


@ -1,4 +1,3 @@
from collections import deque
import glob
import json
import logging
@ -6,19 +5,17 @@ import os
import subprocess
import sys
import tempfile
from collections import deque
from pathlib import Path
from typing import Any, ClassVar, Iterator, Literal, Optional
from typing import Annotated, Any, ClassVar, Iterator, Literal, Optional
import pytest
from agent_protocol_client import (
AgentApi,
ApiClient,
Configuration as ClientConfig,
Step,
)
from agent_protocol_client import AgentApi, ApiClient
from agent_protocol_client import Configuration as ClientConfig
from agent_protocol_client import Step
from colorama import Fore, Style
from openai import _load_client as get_openai_client
from pydantic import BaseModel, constr, Field, validator
from pydantic import BaseModel, Field, constr, validator
from agbenchmark.agent_api_interface import download_agent_artifacts_into_folder
from agbenchmark.agent_interface import copy_challenge_artifacts_into_workspace
@ -49,7 +46,7 @@ class BuiltinChallengeSpec(BaseModel):
class Info(BaseModel):
difficulty: DifficultyLevel
description: constr(regex=r"^Tests if the agent can.*")
description: Annotated[str, constr(regex=r"^Tests if the agent can.*")]
side_effects: list[str] = Field(default_factory=list)
info: Info
@ -184,7 +181,7 @@ class BuiltinChallenge(BaseChallenge):
steps: list[Step] = []
try:
async for step in self.run_challenge(
config, timeout, mock=request.config.getoption("--mock")
config, timeout, mock=bool(request.config.getoption("--mock"))
):
if not task_id:
task_id = step.task_id
@ -199,6 +196,8 @@ class BuiltinChallenge(BaseChallenge):
timed_out = False
except TimeoutError:
timed_out = True
assert isinstance(request.node, pytest.Item)
request.node.user_properties.append(("steps", steps))
request.node.user_properties.append(("n_steps", n_steps))
request.node.user_properties.append(("timed_out", timed_out))
@ -411,15 +410,10 @@ class BuiltinChallenge(BaseChallenge):
def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
logger.info("Loading built-in challenges...")
challenges_path = os.path.dirname(__file__)
challenges_path = Path(__file__).parent
logger.debug(f"Looking for challenge spec files in {challenges_path}...")
json_files = deque(
glob.glob(
f"{challenges_path}/**/data.json",
recursive=True,
)
)
json_files = deque(challenges_path.rglob("data.json"))
logger.debug(f"Found {len(json_files)} built-in challenges.")
@ -431,7 +425,7 @@ def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
ignored += 1
continue
challenge = BuiltinChallenge.from_challenge_spec_file(Path(json_file))
challenge = BuiltinChallenge.from_challenge_spec_file(json_file)
logger.debug(f"Generated test for {challenge.info.name}")
yield challenge
@ -442,8 +436,8 @@ def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
)
def _challenge_should_be_ignored(json_file_path: str):
def _challenge_should_be_ignored(json_file_path: Path):
return (
"challenges/deprecated" in json_file_path
or "challenges/library" in json_file_path
"challenges/deprecated" in json_file_path.as_posix()
or "challenges/library" in json_file_path.as_posix()
)


@ -23,9 +23,10 @@ def test_get_ethereum_price() -> None:
real_eth_price_value = float(real_eth_price)
# Check if the eth price is within $50 of the actual Ethereum price
assert (
abs(real_eth_price_value - eth_price_value) <= 50
), f"AssertionError: Ethereum price is not within $50 of the actual Ethereum price (Provided price: ${eth_price}, Real price: ${real_eth_price})"
assert abs(real_eth_price_value - eth_price_value) <= 50, (
"AssertionError: Ethereum price is not within $50 of the actual Ethereum price "
f"(Provided price: ${eth_price}, Real price: ${real_eth_price})"
)
print("Matches")


@ -23,9 +23,10 @@ def test_get_ethereum_price() -> None:
real_eth_price_value = float(real_eth_price)
# Check if the eth price is within $50 of the actual Ethereum price
assert (
abs(real_eth_price_value - eth_price_value) <= 50
), f"AssertionError: Ethereum price is not within $50 of the actual Ethereum price (Provided price: ${eth_price}, Real price: ${real_eth_price})"
assert abs(real_eth_price_value - eth_price_value) <= 50, (
"AssertionError: Ethereum price is not within $50 of the actual Ethereum price "
f"(Provided price: ${eth_price}, Real price: ${real_eth_price})"
)
print("Matches")


@ -1,4 +1,3 @@
# mypy: ignore-errors
from typing import List, Optional


@ -1,4 +1,4 @@
# mypy: ignore-errors
# pyright: reportMissingImports=false
from typing import List
from sample_code import three_sum


@ -21,7 +21,6 @@ def generate_password(length: int = 8) -> str:
if __name__ == "__main__":
password_length = (
int(sys.argv[sys.argv.index("--length") + 1])
if "--length" in sys.argv else 8
int(sys.argv[sys.argv.index("--length") + 1]) if "--length" in sys.argv else 8
)
print(generate_password(password_length))


@ -1,3 +1,4 @@
# pyright: reportMissingImports=false
import unittest
import password_generator
@ -18,7 +19,9 @@ class TestPasswordGenerator(unittest.TestCase):
def test_password_content(self):
password = password_generator.generate_password()
self.assertTrue(any(c.isdigit() for c in password))
self.assertTrue(any(c in password_generator.string.punctuation for c in password))
self.assertTrue(
any(c in password_generator.string.punctuation for c in password)
)
if __name__ == "__main__":


@ -1,3 +1,4 @@
# pyright: reportMissingImports=false
import unittest
from url_shortener import retrieve_url, shorten_url


@ -56,7 +56,7 @@ def winner(board):
def getLocation():
location = input(
"Choose where to play. Enter two numbers separated by a comma, for example: 1,1 "
"Choose where to play. Enter two numbers separated by a comma [example: 1,1]: "
)
print(f"\nYou picked {location}")
coordinates = [int(x) for x in location.split(",")]
@ -69,7 +69,8 @@ def getLocation():
):
print("You inputted a location in an invalid format")
location = input(
"Choose where to play. Enter two numbers separated by a comma, for example: 1,1 "
"Choose where to play. Enter two numbers separated by a comma "
"[example: 1,1]: "
)
coordinates = [int(x) for x in location.split(",")]
return coordinates


@ -37,15 +37,14 @@ class GameStatus(BaseModel):
winner: Optional[str]
from typing import List
class Game(BaseModel):
game_id: str
players: List[str]
board: dict # This could represent the state of the game board, you might need to flesh this out further
ships: List[ShipPlacement] # List of ship placements for this game
turns: List[Turn] # List of turns that have been taken
players: list[str]
# This could represent the state of the game board,
# you might need to flesh this out further:
board: dict
ships: list[ShipPlacement] # List of ship placements for this game
turns: list[Turn] # List of turns that have been taken
class AbstractBattleship(ABC):
@ -86,7 +85,7 @@ class AbstractBattleship(ABC):
pass
@abstractmethod
def get_game(self) -> Game:
def get_game(self) -> Game | None:
"""
Retrieve the state of the game.
"""
@ -103,5 +102,8 @@ class AbstractBattleship(ABC):
def create_game(self) -> None:
"""
Create a new game.
Returns:
str: The ID of the created game.
"""
pass


@ -1,3 +1,4 @@
# pyright: reportMissingImports=false
import pytest
from abstract_class import ShipPlacement, Turn
from battleship import Battleship


@ -50,7 +50,7 @@ def test_cant_hit_before_ships_placed(battleship_game):
def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id):
game = battleship_game.get_game(initialized_game_id)
battleship_game.get_game(initialized_game_id)
additional_ship = ShipPlacement(
ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal"
)


@ -61,6 +61,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
{"row": 1, "column": "H"},
]
response = None
for index, hit in enumerate(hits):
turn = Turn(target={"row": 2, "column": hit})
response = battleship_game.create_turn(initialized_game_id, turn)
@ -69,7 +70,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
static_turn = Turn(target=static_moves[index])
battleship_game.create_turn(initialized_game_id, static_turn)
assert response.result == "sunk"
assert response and response.result == "sunk"
def test_restart_game(battleship_game):
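
The `assert response and response.result == "sunk"` change above is the usual way to satisfy the type checker when a variable assigned inside a loop may still be `None`. A small self-contained sketch (class and data are invented):

from typing import Optional

class TurnResponse:
    def __init__(self, result: str) -> None:
        self.result = result

def play(turn_results: list[str]) -> Optional[TurnResponse]:
    response: Optional[TurnResponse] = None
    for result in turn_results:
        response = TurnResponse(result)
    return response

response = play(["miss", "hit", "sunk"])
# Without the truthiness check, accessing .result on Optional[TurnResponse]
# is a type error; the assert narrows it to TurnResponse.
assert response and response.result == "sunk"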


@ -37,15 +37,14 @@ class GameStatus(BaseModel):
winner: Optional[str]
from typing import List
class Game(BaseModel):
game_id: str
players: List[str]
board: dict # This could represent the state of the game board, you might need to flesh this out further
ships: List[ShipPlacement] # List of ship placements for this game
turns: List[Turn] # List of turns that have been taken
players: list[str]
# This could represent the state of the game board,
# you might need to flesh this out further:
board: dict
ships: list[ShipPlacement] # List of ship placements for this game
turns: list[Turn] # List of turns that have been taken
class AbstractBattleship(ABC):
@ -86,7 +85,7 @@ class AbstractBattleship(ABC):
pass
@abstractmethod
def get_game(self) -> Game:
def get_game(self, game_id: str) -> Game | None:
"""
Retrieve the state of the game.
"""
@ -100,8 +99,11 @@ class AbstractBattleship(ABC):
pass
@abstractmethod
def create_game(self) -> None:
def create_game(self) -> str:
"""
Create a new game.
Returns:
str: The ID of the created game.
"""
pass


@ -1,14 +1,20 @@
from typing import Dict
from abstract_class import (AbstractBattleship, Game, GameStatus,
ShipPlacement, Turn, TurnResponse)
from abstract_class import (
AbstractBattleship,
Game,
GameStatus,
ShipPlacement,
Turn,
TurnResponse,
)
class Battleship(AbstractBattleship):
def __init__(self):
self.games: Dict[int, Game] = {}
self.games: Dict[str, Game] = {}
def create_game(self) -> int:
def create_game(self) -> str:
game_id = str(len(self.games))
new_game = Game(
game_id=game_id,
@ -19,7 +25,7 @@ class Battleship(AbstractBattleship):
)
self.games[game_id] = new_game
return new_game.game_id
return game_id
def create_ship_placement(self, game_id: str, placement: ShipPlacement) -> None:
game = self.games.get(game_id)
@ -79,38 +85,34 @@ class Battleship(AbstractBattleship):
game.turns.append(turn)
if hit_ship == "hit":
if not hit_ship or hit_ship == "hit": # if no ship or already hit
return TurnResponse(result="miss", ship_type=None)
if hit_ship:
ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
start_row, start_col = (
ship_placement.start["row"],
ord(ship_placement.start["column"]) - ord("A"),
)
ship_positions = [
(
start_row + (i if ship_placement.direction == "vertical" else 0),
start_col + (i if ship_placement.direction == "horizontal" else 0),
)
for i in range(self.SHIP_LENGTHS[hit_ship])
]
if hit_ship:
ship_placement = next(sp for sp in game.ships if sp.ship_type == hit_ship)
start_row, start_col = ship_placement.start["row"], ord(
ship_placement.start["column"]
) - ord("A")
ship_positions = [
(
start_row + (i if ship_placement.direction == "vertical" else 0),
start_col + (i if ship_placement.direction == "horizontal" else 0),
)
for i in range(self.SHIP_LENGTHS[hit_ship])
]
targeted_positions = {
(t.target["row"], ord(t.target["column"]) - ord("A")) for t in game.turns
}
targeted_positions = {
(t.target["row"], ord(t.target["column"]) - ord("A"))
for t in game.turns
}
game.board[(target_row, target_col)] = "hit"
game.board[(target_row, target_col)] = "hit"
if set(ship_positions).issubset(targeted_positions):
for pos in ship_positions:
game.board[pos] = "hit"
return TurnResponse(result="sunk", ship_type=hit_ship)
else:
return TurnResponse(result="hit", ship_type=hit_ship)
if set(ship_positions).issubset(targeted_positions):
for pos in ship_positions:
game.board[pos] = "hit"
return TurnResponse(result="sunk", ship_type=hit_ship)
else:
return TurnResponse(result="hit", ship_type=hit_ship)
def get_game_status(self, game_id: str) -> GameStatus:
game = self.games.get(game_id)
@ -132,12 +134,12 @@ class Battleship(AbstractBattleship):
def get_winner(self, game_id: str) -> str:
game_status = self.get_game_status(game_id)
if game_status.is_game_over:
if game_status.is_game_over and game_status.winner:
return game_status.winner
else:
return None
raise ValueError(f"Game {game_id} isn't over yet")
def get_game(self, game_id: str) -> Game:
def get_game(self, game_id: str) -> Game | None:
return self.games.get(game_id)
def delete_game(self, game_id: str) -> None:


@ -50,7 +50,7 @@ def test_cant_hit_before_ships_placed(battleship_game):
def test_cant_place_ship_after_all_ships_placed(battleship_game, initialized_game_id):
game = battleship_game.get_game(initialized_game_id)
battleship_game.get_game(initialized_game_id)
additional_ship = ShipPlacement(
ship_type="carrier", start={"row": 2, "column": "E"}, direction="horizontal"
)


@ -61,6 +61,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
{"row": 1, "column": "H"},
]
response = None
for index, hit in enumerate(hits):
turn = Turn(target={"row": 2, "column": hit})
response = battleship_game.create_turn(initialized_game_id, turn)
@ -69,7 +70,7 @@ def test_ship_sinking_feedback(battleship_game, initialized_game_id):
static_turn = Turn(target=static_moves[index])
battleship_game.create_turn(initialized_game_id, static_turn)
assert response.result == "sunk"
assert response and response.result == "sunk"
def test_restart_game(battleship_game):


@ -6,7 +6,7 @@ from typing import ClassVar, Iterator, Literal
import pytest
import requests
from agent_protocol_client import AgentApi, Step
from pydantic import BaseModel, validator, ValidationError
from pydantic import BaseModel, ValidationError, validator
from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.utils.data_types import Category, EvalResult
@ -93,11 +93,12 @@ class Eval(ABC):
...
class StringEval(BaseModel, Eval):
type: ReferenceAnswerType
class BaseStringEval(BaseModel, Eval):
# type: ReferenceAnswerType
pass
class ExactStringMatchEval(StringEval):
class ExactStringMatchEval(BaseStringEval):
type: Literal["exact_match"] = "exact_match"
reference_answer: str
@ -109,7 +110,7 @@ class ExactStringMatchEval(StringEval):
return string == self.reference_answer
class FuzzyStringMatchEval(StringEval):
class FuzzyStringMatchEval(BaseStringEval):
type: Literal["fuzzy_match"] = "fuzzy_match"
reference_answer: str
@ -122,7 +123,7 @@ class FuzzyStringMatchEval(StringEval):
return self.reference_answer.lower() in string.lower()
class MustIncludeStringEval(StringEval):
class MustIncludeStringEval(BaseStringEval):
type: Literal["must_include"] = "must_include"
reference_answer: str
@ -134,6 +135,9 @@ class MustIncludeStringEval(StringEval):
return self.reference_answer.lower() in string.lower()
StringEval = ExactStringMatchEval | FuzzyStringMatchEval | MustIncludeStringEval
class UrlMatchEval(BaseModel, Eval):
url: str
"""Example: `"__WIKI__/wiki/Octopus"`"""
@ -142,8 +146,8 @@ class UrlMatchEval(BaseModel, Eval):
def description(self) -> str:
return f"Agent must navigate to '{self.url}'"
def evaluate(self, url: str) -> bool:
return url == resolve_uri(self.url)
def evaluate(self, string: str) -> bool:
return string == resolve_uri(self.url)
class ProgramHtmlEval(BaseModel):
@ -258,7 +262,8 @@ class WebArenaChallengeSpec(BaseModel):
f"{' and '.join(s.base_url for s in sites)}.\n\n"
+ "\n".join(
s.additional_info.format(url=s.base_url)
for s in sites if s.additional_info
for s in sites
if s.additional_info
)
).strip()
@ -391,7 +396,9 @@ class WebArenaChallenge(BaseChallenge):
if request.config.getoption("--nc"):
timeout = 100000
elif cutoff := request.config.getoption("--cutoff"):
timeout = int(cutoff)
timeout = int(cutoff) # type: ignore
assert isinstance(request.node, pytest.Item)
n_steps = 0
timed_out = None
@ -400,7 +407,7 @@ class WebArenaChallenge(BaseChallenge):
eval_results_per_step: list[list[tuple[_Eval, EvalResult]]] = []
try:
async for step in self.run_challenge(
config, timeout, mock=request.config.getoption("--mock")
config, timeout, mock=bool(request.config.getoption("--mock"))
):
if not step.output:
logger.warn(f"Step has no output: {step}")
@ -415,7 +422,7 @@ class WebArenaChallenge(BaseChallenge):
)
step_eval_results = self.evaluate_step_result(
step, mock=request.config.getoption("--mock")
step, mock=bool(request.config.getoption("--mock"))
)
logger.debug(f"Intermediary results: {step_eval_results}")
eval_results_per_step.append(step_eval_results)
@ -462,7 +469,7 @@ class WebArenaChallenge(BaseChallenge):
def load_webarena_challenges(
skip_unavailable: bool = True
skip_unavailable: bool = True,
) -> Iterator[type[WebArenaChallenge]]:
logger.info("Loading WebArena challenges...")
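
Earlier in this file, `StringEval` stops being a base class (now `BaseStringEval`) and becomes a plain union of the concrete evaluators. With Pydantic v1, which this code base appears to use, such a union of models that each carry a `Literal` discriminator field can be parsed directly from dict data. A rough sketch with invented class names:

from typing import Literal, Union
from pydantic import BaseModel, parse_obj_as

class ExactMatch(BaseModel):
    type: Literal["exact_match"] = "exact_match"
    reference_answer: str

class FuzzyMatch(BaseModel):
    type: Literal["fuzzy_match"] = "fuzzy_match"
    reference_answer: str

StringCheck = Union[ExactMatch, FuzzyMatch]

# Union members are tried in order; the Literal field rejects non-matching
# variants, so the right evaluator class is picked from plain dict data.
check = parse_obj_as(StringCheck, {"type": "fuzzy_match", "reference_answer": "42"})
print(type(check).__name__)  # FuzzyMatch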


@ -123,8 +123,10 @@ def check_regression(request: pytest.FixtureRequest) -> None:
with contextlib.suppress(FileNotFoundError):
rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
assert isinstance(request.node, pytest.Function)
assert isinstance(request.node.parent, pytest.Class)
test_name = request.node.parent.name
challenge_location = getattr(request.node.parent.cls, "CHALLENGE_LOCATION", "")
challenge_location = getattr(request.node.cls, "CHALLENGE_LOCATION", "")
skip_string = f"Skipping {test_name} at {challenge_location}"
# Check if the test name exists in the regression tests
@ -148,7 +150,9 @@ def mock(request: pytest.FixtureRequest) -> bool:
Returns:
bool: Whether `--mock` is set for this session.
"""
return request.config.getoption("--mock")
mock = request.config.getoption("--mock")
assert isinstance(mock, bool)
return mock
test_reports: dict[str, Test] = {}
@ -221,7 +225,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
def pytest_collection_modifyitems(
items: list[pytest.Item], config: pytest.Config
items: list[pytest.Function], config: pytest.Config
) -> None:
"""
Pytest hook that is called after initial test collection has been performed.
@ -248,8 +252,9 @@ def pytest_collection_modifyitems(
i = 0
while i < len(items):
item = items[i]
assert item.cls and issubclass(item.cls, BaseChallenge)
challenge = item.cls
challenge_name = item.cls.__name__
challenge_name = challenge.info.name
if not issubclass(challenge, BaseChallenge):
item.warn(
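
Several of the conftest changes above follow the same recipe: pytest APIs such as `config.getoption()` and `item.cls` are loosely typed (`Any` or optional), so a runtime `assert` both documents the expectation and narrows the type for Pyright. An illustrative sketch (the fixture name is invented, assuming a registered `--mock` flag):

import pytest

@pytest.fixture
def mock_mode(request: pytest.FixtureRequest) -> bool:
    value = request.config.getoption("--mock")  # typed as Any by pytest
    assert isinstance(value, bool)              # runtime check + type narrowing
    return value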


@ -18,9 +18,9 @@ def run_benchmark(
maintain: bool = False,
improve: bool = False,
explore: bool = False,
tests: tuple[str] = tuple(),
categories: tuple[str] = tuple(),
skip_categories: tuple[str] = tuple(),
tests: tuple[str, ...] = tuple(),
categories: tuple[str, ...] = tuple(),
skip_categories: tuple[str, ...] = tuple(),
attempts_per_challenge: int = 1,
mock: bool = False,
no_dep: bool = False,
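
The `tuple[str]` to `tuple[str, ...]` change above matters to the type checker: `tuple[str]` describes a tuple of exactly one string, while the ellipsis form describes a homogeneous tuple of any length. A tiny sketch:

def run_selected(tests: tuple[str, ...] = ()) -> None:
    # tuple[str] would only accept one-element tuples like ("TestOne",)
    for name in tests:
        print(f"running {name}")

run_selected(("TestReadFile", "TestWriteFile"))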


@ -53,9 +53,9 @@ class SingletonReportManager:
@classmethod
def clear_instance(cls):
cls.instance = None
cls.INFO_MANAGER = None
cls.REGRESSION_MANAGER = None
cls.SUCCESS_RATE_TRACKER = None
del cls.INFO_MANAGER
del cls.REGRESSION_MANAGER
del cls.SUCCESS_RATE_TRACKER
class BaseReportManager:
@ -99,7 +99,8 @@ class BaseReportManager:
class SessionReportManager(BaseReportManager):
"""Abstracts interaction with the regression tests file"""
tests: dict[str, Test] | Report
tests: dict[str, Test]
report: Report | None = None
def __init__(self, report_file: Path, benchmark_start_time: datetime):
super().__init__(report_file)
@ -109,20 +110,21 @@ class SessionReportManager(BaseReportManager):
def save(self) -> None:
with self.report_file.open("w") as f:
if isinstance(self.tests, Report):
f.write(self.tests.json(indent=4))
if self.report:
f.write(self.report.json(indent=4))
else:
json.dump({k: v.dict() for k, v in self.tests.items()}, f, indent=4)
def load(self) -> None:
super().load()
if "tests" in self.tests: # type: ignore
self.tests = Report.parse_obj(self.tests)
if "tests" in self.tests:
self.report = Report.parse_obj(self.tests)
else:
self.tests = {n: Test.parse_obj(d) for n, d in self.tests.items()}
def add_test_report(self, test_name: str, test_report: Test) -> None:
if isinstance(self.tests, Report):
if self.report:
raise RuntimeError("Session report already finalized")
if test_name.startswith("Test"):
@ -134,10 +136,10 @@ class SessionReportManager(BaseReportManager):
def finalize_session_report(self, config: AgentBenchmarkConfig) -> None:
command = " ".join(sys.argv)
if isinstance(self.tests, Report):
if self.report:
raise RuntimeError("Session report already finalized")
self.tests = Report(
self.report = Report(
command=command.split(os.sep)[-1],
benchmark_git_commit_sha="---",
agent_git_commit_sha="---",
@ -156,7 +158,7 @@ class SessionReportManager(BaseReportManager):
config=config.dict(exclude={"reports_folder"}, exclude_none=True),
)
agent_categories = get_highest_achieved_difficulty_per_category(self.tests)
agent_categories = get_highest_achieved_difficulty_per_category(self.report)
if len(agent_categories) > 1:
save_single_radar_chart(
agent_categories,
@ -166,8 +168,8 @@ class SessionReportManager(BaseReportManager):
self.save()
def get_total_costs(self):
if isinstance(self.tests, Report):
tests = self.tests.tests
if self.report:
tests = self.report.tests
else:
tests = self.tests
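
The `SessionReportManager` changes above replace a single attribute typed `dict[str, Test] | Report` with two narrowly typed attributes, so the `isinstance` checks become simple `None` checks. A generic sketch of the pattern, with invented names:

from dataclasses import dataclass, field

@dataclass
class FinalReport:
    tests: dict[str, str]

@dataclass
class Session:
    # One narrowly typed attribute per state, instead of tests: dict | FinalReport
    tests: dict[str, str] = field(default_factory=dict)
    report: FinalReport | None = None

    def finalize(self) -> FinalReport:
        if self.report:
            raise RuntimeError("Session report already finalized")
        self.report = FinalReport(tests=dict(self.tests))
        return self.report

session = Session()
session.tests["TestReadFile"] = "passed"
print(session.finalize().tests)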


@ -3,7 +3,7 @@ Model definitions used internally and for reports generated during command-line
"""
import logging
from typing import Any, Dict, List
from typing import Annotated, Any, Dict, List
from agent_protocol_client import Step
from pydantic import BaseModel, Field, constr, validator
@ -88,7 +88,7 @@ class Test(BaseModel):
class ReportBase(BaseModel):
command: str
completion_time: str | None = None
benchmark_start_time: constr(regex=datetime_format)
benchmark_start_time: Annotated[str, constr(regex=datetime_format)]
metrics: MetricsOverall
config: Dict[str, str | dict[str, str]]
agent_git_commit_sha: str | None = None


@ -1,4 +1,5 @@
"""Model definitions for use in the API"""
from typing import Annotated
from pydantic import BaseModel, constr
@ -36,7 +37,7 @@ class RunDetails(BaseModel):
run_id: str | None = None
command: str
completion_time: str | None = None
benchmark_start_time: constr(regex=datetime_format)
benchmark_start_time: Annotated[str, constr(regex=datetime_format)]
class BenchmarkRun(BaseModel):


@ -1,14 +1,10 @@
from __future__ import annotations
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
class TaskInput(BaseModel):
pass
class TaskRequestBody(BaseModel):
input: str = Field(
...,
@ -16,7 +12,7 @@ class TaskRequestBody(BaseModel):
description="Input prompt for the task.",
example="Write the words you receive to the file 'output.txt'.",
)
additional_input: Optional[TaskInput] = {}
additional_input: Optional[dict[str, Any]] = Field(default_factory=dict)
class TaskEvalRequestBody(TaskRequestBody):


@ -32,7 +32,10 @@ def _add_ini_and_option(
default: str | bool | int,
**kwargs: Any,
) -> None:
"""Add an option to both the ini file as well as the command line flags, with the latter overriding the former."""
"""
Add an option to both the ini file and the command line flags.
Command line flags/options take precedence over the ini config.
"""
parser.addini(
name,
help + " This overrides the similarly named option from the config.",
@ -44,7 +47,10 @@ def _add_ini_and_option(
def _get_ini_or_option(
config: Any, name: str, choices: Optional[list[str]]
) -> str | None:
"""Get an option from either the ini file or the command line flags, the latter taking precedence."""
"""
Get an option from either the ini file or the command line flags,
with the latter taking precedence.
"""
value = config.getini(name)
if value is not None and choices is not None and value not in choices:
raise ValueError(
@ -73,7 +79,7 @@ def pytest_addoption(parser: Parser) -> None:
default=False,
help=(
"List all non-nodeid dependency names + the tests they resolve to. "
"Will also list all nodeid dependency names when verbosity is high enough."
"Will also list all nodeid dependency names in verbose mode."
),
)
@ -83,7 +89,10 @@ def pytest_addoption(parser: Parser) -> None:
"--list-processed-dependencies",
action="store_true",
default=False,
help="List all dependencies of all tests as a list of nodeids + the names that could not be resolved.",
help=(
"List all dependencies of all tests as a list of nodeids "
"+ the names that could not be resolved."
),
)
# Add an ini option + flag to choose the action to take for failed dependencies
@ -94,7 +103,8 @@ def pytest_addoption(parser: Parser) -> None:
name="failed_dependency_action",
help=(
"The action to take when a test has dependencies that failed. "
'Use "run" to run the test anyway, "skip" to skip the test, and "fail" to fail the test.'
'Use "run" to run the test anyway, "skip" to skip the test, '
'and "fail" to fail the test.'
),
default="skip",
choices=DEPENDENCY_PROBLEM_ACTIONS.keys(),
@ -107,8 +117,10 @@ def pytest_addoption(parser: Parser) -> None:
group,
name="missing_dependency_action",
help=(
"The action to take when a test has dependencies that cannot be found within the current scope. "
'Use "run" to run the test anyway, "skip" to skip the test, and "fail" to fail the test.'
"The action to take when a test has dependencies that cannot be found "
"within the current scope. "
'Use "run" to run the test anyway, "skip" to skip the test, '
'and "fail" to fail the test.'
),
default="warning",
choices=DEPENDENCY_PROBLEM_ACTIONS.keys(),
@ -139,7 +151,7 @@ def pytest_configure(config: Any) -> None:
@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(config: Any, items: list[Item]) -> None:
def pytest_collection_modifyitems(config: Any, items: list[pytest.Function]) -> None:
manager = managers[-1]
# Register the found tests on the manager
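
The reworded help texts above belong to a plugin that registers each setting twice: once as an ini key and once as a CLI flag, with the flag taking precedence, as the updated docstrings state. A condensed sketch of that pattern with illustrative option names, assuming pytest >= 7:

import pytest

def pytest_addoption(parser: pytest.Parser) -> None:
    # Register the same setting as an ini key and as a CLI flag; the flag wins.
    parser.addini(
        "failed_dependency_action",
        "Action to take when a test's dependencies failed.",
        default="skip",
    )
    parser.getgroup("depends").addoption(
        "--failed-dependency-action",
        action="store",
        dest="failed_dependency_action",
        default=None,
        help="Overrides the ini option of the same name.",
    )

def failed_dependency_action(config: pytest.Config) -> str:
    # CLI value (if given) takes precedence over the ini value.
    value = config.getoption("failed_dependency_action")
    return value or config.getini("failed_dependency_action")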


@ -3,7 +3,7 @@
# The name of the marker used
MARKER_NAME = "depends"
# The name of the keyword argument for the marker that contains custom name(s) for the tests
# The name of the kwarg for 'depends' markers that contains custom name(s) for the tests
MARKER_KWARG_ID = "name"
# The name of the keyword argument for the marker that specifies the tests to depend on


@ -57,8 +57,10 @@ def curved_edges(
"""
ax = plt.gca()
for u, v, data in G.edges(data=True):
src = np.array(pos[u])
dst = np.array(pos[v])
_src = pos[u]
_dst = pos[v]
src = np.array(_src)
dst = np.array(_dst)
same_level = abs(src[1] - dst[1]) < 0.01
@ -68,7 +70,7 @@ def curved_edges(
arrow = patches.FancyArrowPatch(
posA=curve[0], # type: ignore
posB=curve[-1], # type: ignore
connectionstyle=f"arc3,rad=0.2",
connectionstyle="arc3,rad=0.2",
color="gray",
arrowstyle="-|>",
mutation_scale=15.0,
@ -80,8 +82,8 @@ def curved_edges(
else:
ax.annotate(
"",
xy=dst,
xytext=src,
xy=_dst,
xytext=_src,
arrowprops=dict(
arrowstyle="-|>", color="gray", lw=1, shrinkA=10, shrinkB=10
),
@ -89,7 +91,8 @@ def curved_edges(
def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, float]]:
"""Compute positions as a tree layout centered on the root with alternating vertical shifts."""
"""Compute positions as a tree layout centered on the root
with alternating vertical shifts."""
bfs_tree = nx.bfs_tree(graph, source=root_node)
levels = {
node: depth
@ -137,7 +140,7 @@ def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, flo
def graph_spring_layout(
dag: nx.DiGraph, labels: Dict[Any, str], tree: bool = True
) -> None:
num_nodes = len(dag.nodes())
num_nodes = len(list(dag.nodes()))
# Setting up the figure and axis
fig, ax = plt.subplots()
ax.axis("off") # Turn off the axis
@ -288,7 +291,8 @@ def graph_interactive_network(
# Optionally, save to a file
# Sync with the flutter UI
# this literally only works in the AutoGPT repo, but this part of the code is not reached if BUILD_SKILL_TREE is false
# this literally only works in the AutoGPT repo, but this part of the code
# is not reached if BUILD_SKILL_TREE is false
write_pretty_json(graph_data, flutter_app_path / "tree_structure.json")
validate_skill_tree(graph_data, "")
@ -332,11 +336,13 @@ def graph_interactive_network(
def extract_subgraph_based_on_category(graph, category):
"""
Extracts a subgraph that includes all nodes and edges required to reach all nodes with a specified category.
Extracts a subgraph that includes all nodes and edges required to reach all nodes
with a specified category.
:param graph: The original graph.
:param category: The target category.
:return: Subgraph with nodes and edges required to reach the nodes with the given category.
:return: Subgraph with nodes and edges required to reach the nodes
with the given category.
"""
subgraph = {"nodes": [], "edges": []}
@ -424,7 +430,8 @@ def get_roots(graph):
def validate_skill_tree(graph, skill_tree_name):
"""
Validate if a given graph represents a valid skill tree and raise appropriate exceptions if not.
Validate if a given graph represents a valid skill tree
and raise appropriate exceptions if not.
:param graph: A dictionary representing the graph with 'nodes' and 'edges'.
:raises: ValueError with a description of the invalidity.
@ -434,7 +441,8 @@ def validate_skill_tree(graph, skill_tree_name):
if cycle_path:
cycle_str = " -> ".join(cycle_path)
raise ValueError(
f"{skill_tree_name} skill tree is circular! Circular path detected: {cycle_str}."
f"{skill_tree_name} skill tree is circular! "
f"Detected circular path: {cycle_str}."
)
# Check for multiple roots


@ -1,18 +1,19 @@
"""
A module to manage dependencies between pytest tests.
This module provides the methods implementing the main logic. These are used in the pytest hooks that are in
__init__.py.
This module provides the methods implementing the main logic.
These are used in the pytest hooks that are in __init__.py.
"""
import collections
import json
import os
from typing import Any, Generator
import colorama
import networkx
from _pytest.nodes import Item
from pytest import Function, Item
from agbenchmark.challenges.base import BaseChallenge
from .constants import MARKER_KWARG_DEPENDENCIES, MARKER_NAME
from .graphs import graph_interactive_network
@ -38,7 +39,8 @@ class TestResult(object):
)
if result.when in self.results:
raise AttributeError(
f"Received multiple results for step {result.when} of test {self.nodeid}"
f"Received multiple results for step {result.when} "
f"of test {self.nodeid}"
)
self.results[result.when] = result.outcome
@ -66,7 +68,7 @@ class TestDependencies(object):
for dep in marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])
]
for dependency in dependencies:
# If the name is not known, try to make it absolute (ie file::[class::]method)
# If the name is not known, try to make it absolute (file::[class::]method)
if dependency not in manager.name_to_nodeids:
absolute_dependency = get_absolute_nodeid(dependency, self.nodeid)
if absolute_dependency in manager.name_to_nodeids:
@ -86,20 +88,20 @@ class DependencyManager(object):
def __init__(self) -> None:
"""Create a new DependencyManager."""
self.options: dict[str, Any] = {}
self._items: list[Item] | None = None
self._items: list[Function] | None = None
self._name_to_nodeids: Any = None
self._nodeid_to_item: Any = None
self._results: Any = None
@property
def items(self) -> list[Item]:
def items(self) -> list[Function]:
"""The collected tests that are managed by this instance."""
if self._items is None:
raise AttributeError("The items attribute has not been set yet")
return self._items
@items.setter
def items(self, items: list[Item]) -> None:
def items(self, items: list[Function]) -> None:
if self._items is not None:
raise AttributeError("The items attribute has already been set")
self._items = items
@ -125,7 +127,8 @@ class DependencyManager(object):
for item in items:
nodeid = clean_nodeid(item.nodeid)
# Process the dependencies of this test
# This uses the mappings created in the previous loop, and can thus not be merged into that loop
# This uses the mappings created in the previous loop,
# and can thus not be merged into that loop
self._dependencies[nodeid] = TestDependencies(item, self)
@property
@ -135,7 +138,7 @@ class DependencyManager(object):
return self._name_to_nodeids
@property
def nodeid_to_item(self) -> dict[str, Item]:
def nodeid_to_item(self) -> dict[str, Function]:
"""A mapping from node ids to test items."""
assert self.items is not None
return self._nodeid_to_item
@ -194,7 +197,9 @@ class DependencyManager(object):
@property
def sorted_items(self) -> Generator:
"""Get a sorted list of tests where all tests are sorted after their dependencies."""
"""
Get a sorted list of tests where all tests are sorted after their dependencies.
"""
# Build a directed graph for sorting
build_skill_tree = os.getenv("BUILD_SKILL_TREE")
BUILD_SKILL_TREE = (
@ -202,8 +207,8 @@ class DependencyManager(object):
)
dag = networkx.DiGraph()
# Insert all items as nodes, to prevent items that have no dependencies and are not dependencies themselves from
# being lost
# Insert all items as nodes, to prevent items that have no dependencies
# and are not dependencies themselves from being lost
dag.add_nodes_from(self.items)
# Insert edges for all the dependencies
@ -214,11 +219,8 @@ class DependencyManager(object):
labels = {}
for item in self.items:
try:
with open(item.cls.CHALLENGE_LOCATION) as f:
data = json.load(f)
except:
data = {}
assert item.cls and issubclass(item.cls, BaseChallenge)
data = item.cls.info.dict()
node_name = get_name(item)
data["name"] = node_name


@ -38,7 +38,8 @@ def strip_nodeid_parameters(nodeid: str) -> str:
def get_absolute_nodeid(nodeid: str, scope: str) -> str:
"""
Transform a possibly relative node id to an absolute one using the scope in which it is used.
Transform a possibly relative node id to an absolute one
using the scope in which it is used.
>>> scope = 'test_file.py::TestClass::test'
>>> get_absolute_nodeid('test2', scope)
@ -49,7 +50,7 @@ def get_absolute_nodeid(nodeid: str, scope: str) -> str:
'test_file2.py::TestClass2::test2'
"""
parts = nodeid.split("::")
# Completely relative (test_name), so add the full current scope (either file::class or file)
# Completely relative (test_name): add the full current scope (file::class or file)
if len(parts) == 1:
base_nodeid = scope.rsplit("::", 1)[0]
nodeid = f"{base_nodeid}::{nodeid}"


@ -15,7 +15,8 @@ def get_data_from_helicone(challenge: str) -> Optional[float]:
# Define the endpoint of your GraphQL server
url = "https://www.helicone.ai/api/graphql"
# Set the headers, usually you'd need to set the content type and possibly an authorization token
# Set the headers, usually you'd need to set the content type
# and possibly an authorization token
headers = {"authorization": f"Bearer {os.environ.get('HELICONE_API_KEY')}"}
# Define the query, variables, and operation name


@ -1,7 +1,18 @@
SCORING_MAP = {
"percentage": "assign a float score that will represent a percentage out of 100. Use decimal points to be even more accurate. 0 represents the worst possible generation, while 100 represents the ideal generation",
"scale": "assign an integer score from a scale of 1-10. 1 represents a really bad generation, while 10 represents an ideal generation",
"binary": "assign a binary score of either 0 or 1. 0 represents a failure, while 1 represents a success",
"percentage": (
"assign a float score that will represent a percentage out of 100. "
"Use decimal points to be even more accurate. "
"0 represents the worst possible generation, "
"while 100 represents the ideal generation"
),
"scale": (
"assign an integer score from a scale of 1-10. "
"1 represents a really bad generation, while 10 represents an ideal generation"
),
"binary": (
"assign a binary score of either 0 or 1. "
"0 represents a failure, while 1 represents a success"
),
}
@ -17,7 +28,7 @@ Here is the ideal response you're comparing to based on the task:
Here is the current machine generated response to the task that you need to evaluate:
{response}
"""
""" # noqa: E501
RUBRIC_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks.
In order to score the generated texts you will {scoring}. Make sure to factor in rubric into your thinking, deliberation, and final result regarding scoring. Return nothing but a float score.
@ -31,7 +42,7 @@ Use the below rubric to guide your thinking about scoring:
Here is the current machine generated response to the task that you need to evaluate:
{response}
"""
""" # noqa: E501
QUESTION_PROMPT = """Ignore previous directions. You are now an expert at evaluating machine generated responses to given tasks.
In order to score the generated texts you will {scoring}. Make sure to think about whether the generated response answers the question well in order to score accurately. Return nothing but a float score.
@ -45,12 +56,12 @@ Here is a question that checks if the task was completed correctly:
Here is the current machine generated response to the task that you need to evaluate:
{response}
"""
""" # noqa: E501
FEW_SHOT_EXAMPLES = """Here are some examples of how to score a machine generated response based on the above:
{examples}
"""
""" # noqa: E501
CUSTOM_PROMPT = """{custom}
{scoring}


@ -202,11 +202,15 @@ def sorted_by_enum_index(
sortable: Iterable[T],
enum: type[Enum],
*,
key: Callable[[T], Enum | None] = lambda x: x, # type: ignore
key: Optional[Callable[[T], Enum | None]] = None,
reverse: bool = False,
) -> list[T]:
return sorted(
sortable,
key=lambda x: enum._member_names_.index(e.name) if (e := key(x)) else 420e3,
key=lambda x: (
enum._member_names_.index(e.name) # type: ignore
if (e := key(x) if key else x)
else 420e3
),
reverse=reverse,
)
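
The reworked helper above sorts items by the declaration order of an enum, pushing items without a value to the end. A standalone sketch of the same idea (the enum and data are invented):

from enum import Enum
from typing import Callable, Iterable, Optional, TypeVar

T = TypeVar("T")

class Difficulty(Enum):
    INTERFACE = "interface"
    BASIC = "basic"
    ADVANCED = "advanced"

def sorted_by_enum_order(
    items: Iterable[T],
    enum: type[Enum],
    key: Callable[[T], Optional[Enum]],
) -> list[T]:
    names = list(enum.__members__)
    # Items whose key is None get a large sentinel index and sort last.
    return sorted(
        items, key=lambda x: names.index(e.name) if (e := key(x)) else len(names)
    )

challenges = [("c3", Difficulty.ADVANCED), ("c1", Difficulty.INTERFACE), ("c2", None)]
print(sorted_by_enum_order(challenges, Difficulty, key=lambda c: c[1]))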


@ -1,13 +0,0 @@
[mypy]
namespace_packages = True
follow_imports = skip
check_untyped_defs = True
disallow_untyped_defs = True
exclude = ^(agbenchmark/challenges/|agent/|venv|venv-dev)
ignore_missing_imports = True
[mypy-agbenchmark.utils.data_types.*]
ignore_errors = True
[mypy-numpy.*]
ignore_errors = True

benchmark/poetry.lock (generated)

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "agent-protocol-client"
@ -197,63 +197,49 @@ tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
[[package]]
name = "autoflake"
version = "1.7.8"
description = "Removes unused imports and unused variables"
optional = false
python-versions = ">=3.7"
files = [
{file = "autoflake-1.7.8-py3-none-any.whl", hash = "sha256:46373ef69b6714f5064c923bb28bd797c4f8a9497f557d87fc36665c6d956b39"},
{file = "autoflake-1.7.8.tar.gz", hash = "sha256:e7e46372dee46fa1c97acf310d99d922b63d369718a270809d7c278d34a194cf"},
]
[package.dependencies]
pyflakes = ">=1.1.0,<3"
tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""}
[[package]]
name = "black"
version = "22.3.0"
version = "23.12.1"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.6.2"
python-versions = ">=3.8"
files = [
{file = "black-22.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2497f9c2386572e28921fa8bec7be3e51de6801f7459dffd6e62492531c47e09"},
{file = "black-22.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5795a0375eb87bfe902e80e0c8cfaedf8af4d49694d69161e5bd3206c18618bb"},
{file = "black-22.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3556168e2e5c49629f7b0f377070240bd5511e45e25a4497bb0073d9dda776a"},
{file = "black-22.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67c8301ec94e3bcc8906740fe071391bce40a862b7be0b86fb5382beefecd968"},
{file = "black-22.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:fd57160949179ec517d32ac2ac898b5f20d68ed1a9c977346efbac9c2f1e779d"},
{file = "black-22.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cc1e1de68c8e5444e8f94c3670bb48a2beef0e91dddfd4fcc29595ebd90bb9ce"},
{file = "black-22.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2fc92002d44746d3e7db7cf9313cf4452f43e9ea77a2c939defce3b10b5c82"},
{file = "black-22.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:a6342964b43a99dbc72f72812bf88cad8f0217ae9acb47c0d4f141a6416d2d7b"},
{file = "black-22.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:328efc0cc70ccb23429d6be184a15ce613f676bdfc85e5fe8ea2a9354b4e9015"},
{file = "black-22.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06f9d8846f2340dfac80ceb20200ea5d1b3f181dd0556b47af4e8e0b24fa0a6b"},
{file = "black-22.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4efa5fad66b903b4a5f96d91461d90b9507a812b3c5de657d544215bb7877a"},
{file = "black-22.3.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8477ec6bbfe0312c128e74644ac8a02ca06bcdb8982d4ee06f209be28cdf163"},
{file = "black-22.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:637a4014c63fbf42a692d22b55d8ad6968a946b4a6ebc385c5505d9625b6a464"},
{file = "black-22.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:863714200ada56cbc366dc9ae5291ceb936573155f8bf8e9de92aef51f3ad0f0"},
{file = "black-22.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10dbe6e6d2988049b4655b2b739f98785a884d4d6b85bc35133a8fb9a2233176"},
{file = "black-22.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:cee3e11161dde1b2a33a904b850b0899e0424cc331b7295f2a9698e79f9a69a0"},
{file = "black-22.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5891ef8abc06576985de8fa88e95ab70641de6c1fca97e2a15820a9b69e51b20"},
{file = "black-22.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:30d78ba6bf080eeaf0b7b875d924b15cd46fec5fd044ddfbad38c8ea9171043a"},
{file = "black-22.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ee8f1f7228cce7dffc2b464f07ce769f478968bfb3dd1254a4c2eeed84928aad"},
{file = "black-22.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ee227b696ca60dd1c507be80a6bc849a5a6ab57ac7352aad1ffec9e8b805f21"},
{file = "black-22.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:9b542ced1ec0ceeff5b37d69838106a6348e60db7b8fdd245294dc1d26136265"},
{file = "black-22.3.0-py3-none-any.whl", hash = "sha256:bc58025940a896d7e5356952228b68f793cf5fcb342be703c3a2669a1488cb72"},
{file = "black-22.3.0.tar.gz", hash = "sha256:35020b8886c022ced9282b51b5a875b6d1ab0c387b31a065b84db7c33085ca79"},
{file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"},
{file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"},
{file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"},
{file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"},
{file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"},
{file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"},
{file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"},
{file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"},
{file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"},
{file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"},
{file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"},
{file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"},
{file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"},
{file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"},
{file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"},
{file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"},
{file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"},
{file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"},
{file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"},
{file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"},
{file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"},
{file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"},
]
[package.dependencies]
click = ">=8.0.0"
mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0"
platformdirs = ">=2"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.7.4)"]
d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
@ -558,6 +544,73 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.6.1)", "types-Pill
test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
[[package]]
name = "coverage"
version = "7.5.1"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
{file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
{file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
{file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
{file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
{file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
{file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
{file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
{file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
{file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
{file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
{file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
{file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
{file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
{file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
{file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
{file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
{file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
{file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
{file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
{file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
{file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras]
toml = ["tomli"]
[[package]]
name = "cycler"
version = "0.12.1"
@ -671,19 +724,19 @@ typing = ["typing-extensions (>=4.8)"]
[[package]]
name = "flake8"
version = "3.9.2"
version = "7.0.0"
description = "the modular source code checker: pep8 pyflakes and co"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
python-versions = ">=3.8.1"
files = [
{file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"},
{file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"},
{file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"},
{file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"},
]
[package.dependencies]
mccabe = ">=0.6.0,<0.7.0"
pycodestyle = ">=2.7.0,<2.8.0"
pyflakes = ">=2.3.0,<2.4.0"
mccabe = ">=0.7.0,<0.8.0"
pycodestyle = ">=2.11.0,<2.12.0"
pyflakes = ">=3.2.0,<3.3.0"
[[package]]
name = "fonttools"
@ -1376,13 +1429,13 @@ traitlets = "*"
[[package]]
name = "mccabe"
version = "0.6.1"
version = "0.7.0"
description = "McCabe checker, plugin for flake8"
optional = false
python-versions = "*"
python-versions = ">=3.6"
files = [
{file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"},
{file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
{file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]]
@ -1973,13 +2026,13 @@ pyasn1 = ">=0.4.6,<0.6.0"
[[package]]
name = "pycodestyle"
version = "2.7.0"
version = "2.11.1"
description = "Python style guide checker"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
python-versions = ">=3.8"
files = [
{file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"},
{file = "pycodestyle-2.7.0.tar.gz", hash = "sha256:c389c1d06bf7904078ca03399a4816f974a1d590090fecea0c63ec26ebaf1cef"},
{file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"},
{file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"},
]
[[package]]
@ -2047,13 +2100,13 @@ email = ["email-validator (>=1.0.3)"]
[[package]]
name = "pyflakes"
version = "2.3.1"
version = "3.2.0"
description = "passive checker of Python programs"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
python-versions = ">=3.8"
files = [
{file = "pyflakes-2.3.1-py2.py3-none-any.whl", hash = "sha256:7893783d01b8a89811dd72d7dfd4d84ff098e5eed95cfa8905b22bbffe52efc3"},
{file = "pyflakes-2.3.1.tar.gz", hash = "sha256:f5bc8ecabc05bb9d291eb5203d6810b49040f6ff446a756326104746cc00c1db"},
{file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"},
{file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"},
]
[[package]]
@ -2085,6 +2138,24 @@ files = [
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pyright"
version = "1.1.364"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyright-1.1.364-py3-none-any.whl", hash = "sha256:865f1e02873c5dc7427c95acf53659a118574010e6fb364e27e47ec5c46a9f26"},
{file = "pyright-1.1.364.tar.gz", hash = "sha256:612a2106a4078ec57efc22b5620729e9bdf4a3c17caba013b534bd33f7d08e5a"},
]
[package.dependencies]
nodeenv = ">=1.6.0"
[package.extras]
all = ["twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]
[[package]]
name = "pysocks"
version = "1.7.1"
@ -2137,6 +2208,24 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
name = "pytest-cov"
version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "python-dateutil"
version = "2.8.2"
@ -2774,4 +2863,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9"
content-hash = "4a980e6d8f54a2f7f6a3c55d4f40ac3a4b27b5ac6573dd2a39e11213a4b126dd"


@ -37,59 +37,49 @@ click-default-group = "^1.2.4"
tabulate = "^0.9.0"
[tool.poetry.group.dev.dependencies]
flake8 = "^3.9.2"
isort = "^5.9.3"
black = "22.3"
autoflake = "^1.4"
black = "^23.12.1"
flake8 = "^7.0.0"
isort = "^5.13.1"
pyright = "^1.1.364"
pandas = "^2.0.3"
gspread = "^5.10.0"
oauth2client = "^4.1.3"
pre-commit = "^3.3.3"
pytest-cov = "^5.0.0"
[tool.poetry.scripts]
agbenchmark = "agbenchmark.__main__:cli"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q"
testpaths = [
"tests", "agbenchmark",
]
asyncio_mode = "auto"
markers = [
"interface",
"code",
"memory",
"iterate",
"adaptability",
"safety",
"content_gen",
"product_advisor"
]
filterwarnings = [
"ignore::pytest.PytestAssertRewriteWarning",
"ignore::matplotlib.MatplotlibDeprecationWarning"
]
[tool.black]
line-length = 88
target-version = ['py310']
include = '\.pyi?$'
packages = ["autogpt"]
extend-exclude = '(/dist|/.venv|/venv|/build|/agent|agbenchmark/challenges)/'
[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
skip_glob = [".tox", "__pycache__", "*.pyc", "venv*/*", "reports", "venv", "env", "node_modules", ".env", ".venv", "dist", "agent/*", "agbenchmark/challenges/*"]
skip_glob = ["reports"]
[tool.poetry.scripts]
agbenchmark = "agbenchmark.__main__:cli"
[tool.pyright]
pythonVersion = "3.10"
exclude = [
"notebooks/**",
"reports/**",
"**/node_modules",
"**/__pycache__",
"**/.*",
]
ignore = [
"../forge/**"
]
[tool.pytest.ini_options]
testpaths = ["tests"]


@ -17,7 +17,7 @@ def print_markdown_report(report_json_file: str):
report = Report.parse_file(report_json_file)
# Header and metadata
click.echo(f"# Benchmark Report")
click.echo("# Benchmark Report")
click.echo(f"- ⌛ **Run time:** `{report.metrics.run_time}`")
click.echo(
f" - **Started at:** `{report.benchmark_start_time[:16].replace('T', '` `')}`"


@ -1,11 +1,16 @@
import datetime
import time
import pytest
import requests
URL_BENCHMARK = "http://localhost:8080/ap/v1"
URL_AGENT = "http://localhost:8000/ap/v1"
import datetime
import time
try:
response = requests.get(f"{URL_AGENT}/agent/tasks")
except requests.exceptions.ConnectionError:
pytest.skip("No agent available to test against", allow_module_level=True)
@pytest.mark.parametrize(
@ -20,7 +25,8 @@ import time
),
(
"f219f3d3-a41b-45a9-a3d0-389832086ee8",
"Read the file called file_to_read.txt and write its content to a file called output.txt",
"Read the file called file_to_read.txt "
"and write its content to a file called output.txt",
1,
"ReadFile",
False,
@ -28,7 +34,11 @@ import time
],
)
def test_entire_workflow(
eval_id, input_text, expected_artifact_length, test_name, should_be_successful
eval_id: str,
input_text: str,
expected_artifact_length: int,
test_name: str,
should_be_successful: bool,
):
task_request = {"eval_id": eval_id, "input": input_text}
response = requests.get(f"{URL_AGENT}/agent/tasks")
@ -64,7 +74,7 @@ def test_entire_workflow(
)
assert step_response.status_code == 200
step_response = step_response.json()
assert step_response["is_last"] == True # Assuming is_last is always True
assert step_response["is_last"] is True # Assuming is_last is always True
eval_response = requests.post(
URL_BENCHMARK + "/agent/tasks/" + task_response_benchmark_id + "/evaluations",

cli.py
View File

@ -131,7 +131,9 @@ def start(agent_name: str, no_setup: bool):
script_dir = os.path.dirname(os.path.realpath(__file__))
agent_dir = os.path.join(
script_dir,
f"agents/{agent_name}" if agent_name not in ["autogpt", "forge"] else agent_name,
f"agents/{agent_name}"
if agent_name not in ["autogpt", "forge"]
else agent_name,
)
run_command = os.path.join(agent_dir, "run")
run_bench_command = os.path.join(agent_dir, "run_benchmark")
@ -247,7 +249,9 @@ def start(agent_name, subprocess_args):
script_dir = os.path.dirname(os.path.realpath(__file__))
agent_dir = os.path.join(
script_dir,
f"agents/{agent_name}" if agent_name not in ["autogpt", "forge"] else agent_name,
f"agents/{agent_name}"
if agent_name not in ["autogpt", "forge"]
else agent_name,
)
benchmark_script = os.path.join(agent_dir, "run_benchmark")
if os.path.exists(agent_dir) and os.path.isfile(benchmark_script):

View File

@ -202,7 +202,7 @@ class MyAgent(Agent):
def __init__(
self,
settings: AgentSettings,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
file_storage: FileStorage,
legacy_config: Config,
):
@ -219,7 +219,7 @@ class MyAgent(Agent):
def __init__(
self,
settings: AgentSettings,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
file_storage: FileStorage,
legacy_config: Config,
):

View File

@ -1,15 +1,11 @@
[flake8]
max-line-length = 88
select = "E303, W293, W292, E305, E231, E302"
# Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude =
.tox,
__pycache__,
.git,
__pycache__/,
*.pyc,
.env
venv*/*,
.venv/*,
reports/*,
dist/*,
agent/*,
code,
agbenchmark/challenges/*
.pytest_cache/,
venv*/,
.venv/,

forge/.gitignore
View File

@ -160,7 +160,8 @@ CURRENT_BULLETIN.md
agbenchmark_config/workspace
agbenchmark_config/reports
*.sqlite
*.sqlite*
*.db
.agbench
.agbenchmark
.benchmarks
@ -168,7 +169,7 @@ agbenchmark_config/reports
.pytest_cache
.vscode
ig_*
agent.db
agbenchmark_config/updates.json
agbenchmark_config/challenges_already_beaten.json
agbenchmark_config/temp_folder/*
test_workspace/

View File

@ -1,43 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ['--maxkb=500']
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.11
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.11
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: 'v1.3.0'
# hooks:
# - id: mypy
- repo: local
hooks:
- id: autoflake
name: autoflake
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
language: python
types: [ python ]
# Mono repo has bronken this TODO: fix
# - id: pytest-check
# name: pytest-check
# entry: pytest
# language: system
# pass_filenames: false
# always_run: true

View File

@ -9,27 +9,24 @@ from forge.logging.config import configure_logging
logger = logging.getLogger(__name__)
logo = """\n\n
[ASCII-art banner spelling "AutoGPT" / "Forge", shown here in its old and new variants and ending with the version tag v0.1.0; the column alignment of the art is not preserved in this view]
\n"""

View File

@ -1,15 +1,7 @@
from .base import AgentMeta, BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from .components import (
AgentComponent,
ComponentEndpointError,
ComponentSystemError,
EndpointPipelineError,
)
from .protocols import (
AfterExecute,
AfterParse,
CommandProvider,
DirectiveProvider,
ExecutionFailure,
MessageProvider,
)
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
__all__ = [
"BaseAgent",
"BaseAgentConfiguration",
"BaseAgentSettings",
]

View File

@ -24,7 +24,6 @@ from forge.agent_protocol.models.task import (
TaskStepsListResponse,
)
from forge.file_storage.base import FileStorage
from forge.utils.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@ -79,7 +78,8 @@ class Agent:
else:
logger.warning(
f"Frontend not found. {frontend_path} does not exist. The frontend will not be served"
f"Frontend not found. {frontend_path} does not exist. "
"The frontend will not be served."
)
app.add_middleware(AgentMiddleware, agent=self)
@ -94,34 +94,25 @@ class Agent:
"""
Create a task for the agent.
"""
try:
task = await self.db.create_task(
input=task_request.input,
additional_input=task_request.additional_input,
)
return task
except Exception as e:
raise
task = await self.db.create_task(
input=task_request.input,
additional_input=task_request.additional_input,
)
return task
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
"""
List all tasks that the agent has created.
"""
try:
tasks, pagination = await self.db.list_tasks(page, pageSize)
response = TaskListResponse(tasks=tasks, pagination=pagination)
return response
except Exception as e:
raise
tasks, pagination = await self.db.list_tasks(page, pageSize)
response = TaskListResponse(tasks=tasks, pagination=pagination)
return response
async def get_task(self, task_id: str) -> Task:
"""
Get a task by ID.
"""
try:
task = await self.db.get_task(task_id)
except Exception as e:
raise
task = await self.db.get_task(task_id)
return task
async def list_steps(
@ -130,12 +121,9 @@ class Agent:
"""
List the IDs of all steps that the task has created.
"""
try:
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
response = TaskStepsListResponse(steps=steps, pagination=pagination)
return response
except Exception as e:
raise
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
response = TaskStepsListResponse(steps=steps, pagination=pagination)
return response
async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
"""
@ -147,11 +135,8 @@ class Agent:
"""
Get a step by ID.
"""
try:
step = await self.db.get_step(task_id, step_id)
return step
except Exception as e:
raise
step = await self.db.get_step(task_id, step_id)
return step
async def list_artifacts(
self, task_id: str, page: int = 1, pageSize: int = 10
@ -159,62 +144,45 @@ class Agent:
"""
List the artifacts that the task has created.
"""
try:
artifacts, pagination = await self.db.list_artifacts(
task_id, page, pageSize
)
return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
except Exception as e:
raise
artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
async def create_artifact(
self, task_id: str, file: UploadFile, relative_path: str
self, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact:
"""
Create an artifact for the task.
"""
data = None
file_name = file.filename or str(uuid4())
try:
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
# Check if relative path ends with filename
if relative_path.endswith(file_name):
file_path = relative_path
else:
file_path = os.path.join(relative_path, file_name)
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
# Check if relative path ends with filename
if relative_path.endswith(file_name):
file_path = relative_path
else:
file_path = os.path.join(relative_path, file_name)
await self.workspace.write_file(file_path, data)
await self.workspace.write_file(file_path, data)
artifact = await self.db.create_artifact(
task_id=task_id,
file_name=file_name,
relative_path=relative_path,
agent_created=False,
)
except Exception as e:
raise
artifact = await self.db.create_artifact(
task_id=task_id,
file_name=file_name,
relative_path=relative_path,
agent_created=False,
)
return artifact
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
async def get_artifact(self, task_id: str, artifact_id: str) -> StreamingResponse:
"""
Get an artifact by ID.
"""
try:
artifact = await self.db.get_artifact(artifact_id)
if artifact.file_name not in artifact.relative_path:
file_path = os.path.join(artifact.relative_path, artifact.file_name)
else:
file_path = artifact.relative_path
retrieved_artifact = self.workspace.read_file(file_path)
except NotFoundError as e:
raise
except FileNotFoundError as e:
raise
except Exception as e:
raise
artifact = await self.db.get_artifact(artifact_id)
if artifact.file_name not in artifact.relative_path:
file_path = os.path.join(artifact.relative_path, artifact.file_name)
else:
file_path = artifact.relative_path
retrieved_artifact = self.workspace.read_file(file_path, binary=True)
return StreamingResponse(
BytesIO(retrieved_artifact),

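For context on the `get_artifact` change above: wrapping the file's bytes in `BytesIO` and returning a `StreamingResponse` is the usual FastAPI way to send in-memory data back to the client. A minimal, self-contained sketch of the pattern (the route and payload below are illustrative, not forge's actual endpoint):

```python
from io import BytesIO

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/files/{name}")
def download_file(name: str) -> StreamingResponse:
    # In forge the bytes come from FileStorage.read_file(path, binary=True);
    # here a small payload is fabricated purely for illustration.
    data = f"contents of {name}\n".encode()
    return StreamingResponse(
        BytesIO(data),
        media_type="application/octet-stream",
        headers={"Content-Disposition": f'attachment; filename="{name}"'},
    )
```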
View File

@ -1,6 +1,7 @@
from pathlib import Path
import pytest
from fastapi import UploadFile
from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.models.task import (
@ -16,16 +17,23 @@ from .agent import Agent
@pytest.fixture
def agent():
def agent(test_workspace: Path):
db = AgentDB("sqlite:///test.db")
config = FileStorageConfiguration(root=Path("./test_workspace"))
config = FileStorageConfiguration(root=test_workspace)
workspace = LocalFileStorage(config)
return Agent(db, workspace)
@pytest.mark.skip
@pytest.fixture
def file_upload():
this_file = Path(__file__)
file_handle = this_file.open("rb")
yield UploadFile(file_handle, filename=this_file.name)
file_handle.close()
@pytest.mark.asyncio
async def test_create_task(agent):
async def test_create_task(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
@ -33,20 +41,18 @@ async def test_create_task(agent):
assert task.input == "test_input"
@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_tasks(agent):
async def test_list_tasks(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
await agent.create_task(task_request)
tasks = await agent.list_tasks()
assert isinstance(tasks, TaskListResponse)
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_task(agent):
async def test_get_task(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
@ -55,9 +61,9 @@ async def test_get_task(agent):
assert retrieved_task.task_id == task.task_id
@pytest.mark.skip
@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_create_and_execute_step(agent):
async def test_execute_step(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
@ -65,14 +71,14 @@ async def test_create_and_execute_step(agent):
step_request = StepRequestBody(
input="step_input", additional_input={"input": "additional_test_input"}
)
step = await agent.create_and_execute_step(task.task_id, step_request)
step = await agent.execute_step(task.task_id, step_request)
assert step.input == "step_input"
assert step.additional_input == {"input": "additional_test_input"}
@pytest.mark.skip
@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_get_step(agent):
async def test_get_step(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
@ -80,38 +86,52 @@ async def test_get_step(agent):
step_request = StepRequestBody(
input="step_input", additional_input={"input": "additional_test_input"}
)
step = await agent.create_and_execute_step(task.task_id, step_request)
step = await agent.execute_step(task.task_id, step_request)
retrieved_step = await agent.get_step(task.task_id, step.step_id)
assert retrieved_step.step_id == step.step_id
@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_artifacts(agent):
artifacts = await agent.list_artifacts()
assert isinstance(artifacts, list)
async def test_list_artifacts(agent: Agent):
tasks = await agent.list_tasks()
assert tasks.tasks, "No tasks in test.db"
artifacts = await agent.list_artifacts(tasks.tasks[0].task_id)
assert isinstance(artifacts.artifacts, list)
@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_artifact(agent):
async def test_create_artifact(agent: Agent, file_upload: UploadFile):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
artifact = await agent.create_artifact(task.task_id, artifact_request)
assert artifact.uri == "test_uri"
artifact = await agent.create_artifact(
task_id=task.task_id,
file=file_upload,
relative_path=f"a_dir/{file_upload.filename}",
)
assert artifact.file_name == file_upload.filename
assert artifact.relative_path == f"a_dir/{file_upload.filename}"
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_artifact(agent):
async def test_create_and_get_artifact(agent: Agent, file_upload: UploadFile):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
artifact = await agent.create_artifact(task.task_id, artifact_request)
artifact = await agent.create_artifact(
task_id=task.task_id,
file=file_upload,
relative_path=f"b_dir/{file_upload.filename}",
)
await file_upload.seek(0)
file_upload_content = await file_upload.read()
retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
assert retrieved_artifact.artifact_id == artifact.artifact_id
retrieved_artifact_content = bytearray()
async for b in retrieved_artifact.body_iterator:
retrieved_artifact_content.extend(b) # type: ignore
assert retrieved_artifact_content == file_upload_content

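The `test_workspace: Path` fixture used above is not defined in this file; per the commit notes the tests now run in an actual temporary directory instead of `forge/test_workspace/`. A minimal sketch of what such a shared conftest fixture could look like, built on pytest's built-in `tmp_path` (illustrative — the real fixture may differ):

```python
# conftest.py (sketch)
from pathlib import Path

import pytest


@pytest.fixture
def test_workspace(tmp_path: Path) -> Path:
    """Provide an isolated workspace directory that pytest cleans up afterwards."""
    workspace = tmp_path / "workspace"
    workspace.mkdir()
    return workspace
```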
View File

@ -5,22 +5,21 @@ import inspect
import logging
from abc import ABCMeta, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Generic,
Iterator,
Optional,
ParamSpec,
TypeVar,
cast,
overload,
)
from colorama import Fore
from pydantic import BaseModel, Field, validator
if TYPE_CHECKING:
from forge.models.action import ActionProposal, ActionResult
from forge.agent import protocols
from forge.agent.components import (
AgentComponent,
@ -29,15 +28,10 @@ from forge.agent.components import (
)
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.config.config import ConfigBuilder
from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName
from forge.llm.providers.schema import ChatModelInfo
from forge.models.config import (
Configurable,
SystemConfiguration,
SystemSettings,
UserConfigurable,
)
from forge.models.action import ActionResult, AnyProposal
from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable
logger = logging.getLogger(__name__)
@ -133,17 +127,7 @@ class AgentMeta(ABCMeta):
return instance
class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
C = TypeVar("C", bound=AgentComponent)
default_settings = BaseAgentSettings(
name="BaseAgent",
description=__doc__ if __doc__ else "",
)
class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
def __init__(
self,
settings: BaseAgentSettings,
@ -173,13 +157,13 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
return self.config.send_token_limit or self.llm.max_tokens * 3 // 4
@abstractmethod
async def propose_action(self) -> ActionProposal:
async def propose_action(self) -> AnyProposal:
...
@abstractmethod
async def execute(
self,
proposal: ActionProposal,
proposal: AnyProposal,
user_feedback: str = "",
) -> ActionResult:
...
@ -187,7 +171,7 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
@abstractmethod
async def do_not_execute(
self,
denied_proposal: ActionProposal,
denied_proposal: AnyProposal,
user_feedback: str,
) -> ActionResult:
...
@ -203,13 +187,16 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
@overload
async def run_pipeline(
self, protocol_method: Callable[P, None], *args, retry_limit: int = 3
self,
protocol_method: Callable[P, None | Awaitable[None]],
*args,
retry_limit: int = 3,
) -> list[None]:
...
async def run_pipeline(
self,
protocol_method: Callable[P, Iterator[T] | None],
protocol_method: Callable[P, Iterator[T] | None | Awaitable[None]],
*args,
retry_limit: int = 3,
) -> list[T] | list[None]:
@ -240,7 +227,10 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
)
continue
method = getattr(component, method_name, None)
method = cast(
Callable[..., Iterator[T] | None | Awaitable[None]] | None,
getattr(component, method_name, None),
)
if not callable(method):
continue
@ -248,10 +238,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
while component_attempts < retry_limit:
try:
component_args = self._selective_copy(args)
if inspect.iscoroutinefunction(method):
result = await method(*component_args)
else:
result = method(*component_args)
result = method(*component_args)
if inspect.isawaitable(result):
result = await result
if result is not None:
method_result.extend(result)
args = component_args
@ -269,9 +258,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
break
# Successful pipeline execution
break
except EndpointPipelineError:
except EndpointPipelineError as e:
self._trace.append(
f"{Fore.LIGHTRED_EX}{component.__class__.__name__}: "
f"{Fore.LIGHTRED_EX}{e.triggerer.__class__.__name__}: "
f"EndpointPipelineError{Fore.RESET}"
)
# Restart from the beginning on EndpointPipelineError

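The `run_pipeline` change above swaps the `iscoroutinefunction` branch for a call-then-inspect pattern: call the method unconditionally and `await` the result only if it is awaitable, which also handles plain functions that happen to return a coroutine. A stand-alone sketch of that dispatch (names are illustrative):

```python
import asyncio
import inspect
from typing import Any, Callable


async def call_maybe_async(func: Callable[..., Any], *args: Any) -> Any:
    """Call `func`; await the result only if it turns out to be awaitable."""
    result = func(*args)
    if inspect.isawaitable(result):
        result = await result
    return result


def sync_handler(x: int) -> int:
    return x + 1


async def async_handler(x: int) -> int:
    return x + 2


async def main() -> None:
    print(await call_maybe_async(sync_handler, 1))   # 2
    print(await call_maybe_async(async_handler, 1))  # 3


asyncio.run(main())
```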
View File

@ -36,8 +36,9 @@ class AgentComponent(ABC):
class ComponentEndpointError(Exception):
"""Error of a single protocol method on a component."""
def __init__(self, message: str = ""):
def __init__(self, message: str, component: AgentComponent):
self.message = message
self.triggerer = component
super().__init__(message)

View File

@ -1,14 +1,13 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING, Awaitable, Generic, Iterator
from forge.models.action import ActionResult, AnyProposal
from .components import AgentComponent
if TYPE_CHECKING:
from forge.command.command import Command
from forge.llm.providers import ChatMessage
from forge.models.action import ActionResult
from .base import ActionProposal
class DirectiveProvider(AgentComponent):
@ -34,19 +33,19 @@ class MessageProvider(AgentComponent):
...
class AfterParse(AgentComponent):
class AfterParse(AgentComponent, Generic[AnyProposal]):
@abstractmethod
def after_parse(self, result: "ActionProposal") -> None:
def after_parse(self, result: AnyProposal) -> None | Awaitable[None]:
...
class ExecutionFailure(AgentComponent):
@abstractmethod
def execution_failure(self, error: Exception) -> None:
def execution_failure(self, error: Exception) -> None | Awaitable[None]:
...
class AfterExecute(AgentComponent):
@abstractmethod
def after_execute(self, result: "ActionResult") -> None:
def after_execute(self, result: "ActionResult") -> None | Awaitable[None]:
...

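With the return types widened to `None | Awaitable[None]`, a component can implement these protocol methods either synchronously or as coroutines; the agent's pipeline awaits the result when needed (see the `run_pipeline` hunk above). An illustrative component implementing one of each, assuming `forge` is importable (the class itself is hypothetical):

```python
import logging

from forge.agent.protocols import AfterExecute, ExecutionFailure
from forge.models.action import ActionResult

logger = logging.getLogger(__name__)


class ResultLogger(AfterExecute, ExecutionFailure):
    """Hypothetical component that logs outcomes."""

    async def after_execute(self, result: ActionResult) -> None:
        # Coroutine implementation: the pipeline awaits the returned awaitable.
        logger.info("Executed action, result: %s", result)

    def execution_failure(self, error: Exception) -> None:
        # A plain synchronous implementation is equally valid.
        logger.error("Action failed: %s", error)
```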
View File

@ -1,39 +1,16 @@
"""
Routes for the Agent Service.
This module defines the API routes for the Agent service. While there are multiple endpoints provided by the service,
the ones that require special attention due to their complexity are:
This module defines the API routes for the Agent service.
1. `execute_agent_task_step`:
This route is significant because this is where the agent actually performs the work. The function handles
executing the next step for a task based on its current state, and it requires careful implementation to ensure
all scenarios (like the presence or absence of steps or a step marked as `last_step`) are handled correctly.
2. `upload_agent_task_artifacts`:
This route allows for the upload of artifacts, supporting various URI types (e.g., s3, gcs, ftp, http).
The support for different URI types makes it a bit more complex, and it's important to ensure that all
supported URI types are correctly managed. NOTE: The AutoGPT team will eventually handle the most common
uri types for you.
3. `create_agent_task`:
While this is a simpler route, it plays a crucial role in the workflow, as it's responsible for the creation
of a new task.
Developers and contributors should be especially careful when making modifications to these routes to ensure
consistency and correctness in the system's behavior.
Developers and contributors should be especially careful when making modifications
to these routes to ensure consistency and correctness in the system's behavior.
"""
import json
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from forge.utils.exceptions import (
NotFoundError,
get_detailed_traceback,
get_exception_message,
)
from fastapi import APIRouter, HTTPException, Query, Request, Response, UploadFile
from fastapi.responses import StreamingResponse
from .models import (
Artifact,
@ -46,6 +23,9 @@ from .models import (
TaskStepsListResponse,
)
if TYPE_CHECKING:
from forge.agent.agent import Agent
base_router = APIRouter()
logger = logging.getLogger(__name__)
@ -73,10 +53,10 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
Args:
request (Request): FastAPI request object.
task (TaskRequestBody): The task request containing input and additional input data.
task (TaskRequestBody): The task request containing input data.
Returns:
Task: A new task with task_id, input, additional_input, and empty lists for artifacts and steps.
Task: A new task with task_id, input, and additional_input set.
Example:
Request (TaskRequestBody defined in schema.py):
@ -93,46 +73,32 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
"artifacts": [],
}
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
task_request = await agent.create_task(task_request)
return Response(
content=task_request.json(),
status_code=200,
media_type="application/json",
)
task = await agent.create_task(task_request)
return task
except Exception:
logger.exception(f"Error whilst trying to create a task: {task_request}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
async def list_agent_tasks(
request: Request,
page: Optional[int] = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1),
) -> TaskListResponse:
"""
Retrieves a paginated list of all tasks.
Args:
request (Request): FastAPI request object.
page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): The number of tasks per page for pagination. Defaults to 10.
page (int, optional): Page number for pagination. Default: 1
page_size (int, optional): Number of tasks per page for pagination. Default: 10
Returns:
TaskListResponse: A response object containing a list of tasks and pagination details.
TaskListResponse: A list of tasks, and pagination details.
Example:
Request:
@ -158,34 +124,13 @@ async def list_agent_tasks(
}
}
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
tasks = await agent.list_tasks(page, page_size)
return Response(
content=tasks.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps({"error": "Tasks not found"}),
status_code=404,
media_type="application/json",
)
return tasks
except Exception:
logger.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
@ -239,36 +184,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
}
]
}
"""
agent = request["agent"]
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
task = await agent.get_task(task_id)
return Response(
content=task.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
return task
except Exception:
logger.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get(
@ -279,8 +202,8 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
async def list_agent_task_steps(
request: Request,
task_id: str,
page: Optional[int] = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskStepsListResponse:
"""
Retrieves a paginated list of steps associated with a specific task.
@ -289,10 +212,10 @@ async def list_agent_task_steps(
request (Request): FastAPI request object.
task_id (str): The ID of the task.
page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): The number of steps per page for pagination. Defaults to 10.
page_size (int, optional): Number of steps per page for pagination. Default: 10.
Returns:
TaskStepsListResponse: A response object containing a list of steps and pagination details.
TaskStepsListResponse: A list of steps, and pagination details.
Example:
Request:
@ -315,54 +238,40 @@ async def list_agent_task_steps(
"pageSize": 10
}
}
"""
agent = request["agent"]
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
steps = await agent.list_steps(task_id, page, page_size)
return Response(
content=steps.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception("Error whilst trying to list steps")
return Response(
content=json.dumps({"error": "Steps not found"}),
status_code=404,
media_type="application/json",
)
return steps
except Exception:
logger.exception("Error whilst trying to list steps")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
async def execute_agent_task_step(
request: Request, task_id: str, step: Optional[StepRequestBody] = None
request: Request, task_id: str, step_request: Optional[StepRequestBody] = None
) -> Step:
"""
Executes the next step for a specified task based on the current task status and returns the
executed step with additional feedback fields.
Executes the next step for a specified task based on the current task status and
returns the executed step with additional feedback fields.
Depending on the current state of the task, the following scenarios are supported:
This route is significant because this is where the agent actually performs work.
The function handles executing the next step for a task based on its current state,
and it requires careful implementation to ensure all scenarios (like the presence
or absence of steps or a step marked as `last_step`) are handled correctly.
Depending on the current state of the task, the following scenarios are possible:
1. No steps exist for the task.
2. There is at least one step already for the task, and the task does not have a completed step marked as `last_step`.
2. There is at least one step already for the task, and the task does not have a
completed step marked as `last_step`.
3. There is a completed step marked as `last_step` already on the task.
In each of these scenarios, a step object will be returned with two additional fields: `output` and `additional_output`.
In each of these scenarios, a step object will be returned with two additional
fields: `output` and `additional_output`.
- `output`: Provides the primary response or feedback to the user.
- `additional_output`: Supplementary information or data. Its specific content is not strictly defined and can vary based on the step or agent's implementation.
- `additional_output`: Supplementary information or data. Its specific content is
not strictly defined and can vary based on the step or agent's implementation.
Args:
request (Request): FastAPI request object.
@ -389,39 +298,17 @@ async def execute_agent_task_step(
...
}
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
# An empty step request represents a yes to continue command
if not step:
step = StepRequestBody(input="y")
if not step_request:
step_request = StepRequestBody(input="y")
step = await agent.execute_step(task_id, step)
return Response(
content=step.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps({"error": f"Task not found {task_id}"}),
status_code=404,
media_type="application/json",
)
step = await agent.execute_step(task_id, step_request)
return step
except Exception:
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get(
@ -450,31 +337,13 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
...
}
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
step = await agent.get_step(task_id, step_id)
return Response(content=step.json(), status_code=200)
except NotFoundError:
logger.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps({"error": "Step not found"}),
status_code=404,
media_type="application/json",
)
return step
except Exception:
logger.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get(
@ -485,8 +354,8 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
async def list_agent_task_artifacts(
request: Request,
task_id: str,
page: Optional[int] = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskArtifactsListResponse:
"""
Retrieves a paginated list of artifacts associated with a specific task.
@ -495,10 +364,10 @@ async def list_agent_task_artifacts(
request (Request): FastAPI request object.
task_id (str): The ID of the task.
page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): The number of items per page for pagination. Defaults to 10.
page_size (int, optional): Number of items per page for pagination. Default: 10.
Returns:
TaskArtifactsListResponse: A response object containing a list of artifacts and pagination details.
TaskArtifactsListResponse: A list of artifacts, and pagination details.
Example:
Request:
@ -518,52 +387,33 @@ async def list_agent_task_artifacts(
"pageSize": 10
}
}
"""
agent = request["agent"]
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
artifacts: TaskArtifactsListResponse = await agent.list_artifacts(
task_id, page, page_size
)
artifacts = await agent.list_artifacts(task_id, page, page_size)
return artifacts
except NotFoundError:
logger.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps({"error": "Artifacts not found for task_id"}),
status_code=404,
media_type="application/json",
)
except Exception:
logger.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.post(
"/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact
)
async def upload_agent_task_artifacts(
request: Request, task_id: str, file: UploadFile, relative_path: Optional[str] = ""
request: Request, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact:
"""
This endpoint is used to upload an artifact associated with a specific task. The artifact is provided as a file.
This endpoint is used to upload an artifact (file) associated with a specific task.
Args:
request (Request): The FastAPI request object.
task_id (str): The unique identifier of the task for which the artifact is being uploaded.
task_id (str): The ID of the task for which the artifact is being uploaded.
file (UploadFile): The file being uploaded as an artifact.
relative_path (str): The relative path for the file. This is a query parameter.
Returns:
Artifact: An object containing metadata of the uploaded artifact, including its unique identifier.
Artifact: Metadata object for the uploaded artifact, including its ID and path.
Example:
Request:
@ -579,35 +429,17 @@ async def upload_agent_task_artifacts(
"relative_path": "/my_folder/my_other_folder/",
"file_name": "main.py"
}
"""
agent = request["agent"]
""" # noqa: E501
agent: "Agent" = request["agent"]
if file is None:
return Response(
content=json.dumps({"error": "File must be specified"}),
status_code=404,
media_type="application/json",
)
raise HTTPException(status_code=400, detail="File must be specified")
try:
artifact = await agent.create_artifact(task_id, file, relative_path)
return Response(
content=artifact.json(),
status_code=200,
media_type="application/json",
)
return artifact
except Exception:
logger.exception(f"Error whilst trying to upload artifact: {task_id}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get(
@ -617,7 +449,7 @@ async def upload_agent_task_artifacts(
)
async def download_agent_task_artifact(
request: Request, task_id: str, artifact_id: str
) -> FileResponse:
) -> StreamingResponse:
"""
Downloads an artifact associated with a specific task.
@ -636,32 +468,9 @@ async def download_agent_task_artifact(
Response:
<file_content_of_artifact>
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
return await agent.get_artifact(task_id, artifact_id)
except NotFoundError:
logger.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{
"error": f"Artifact not found "
"- task_id: {task_id}, artifact_id: {artifact_id}"
}
),
status_code=404,
media_type="application/json",
)
except Exception:
logger.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{
"error": f"Internal server error "
"- task_id: {task_id}, artifact_id: {artifact_id}",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise

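With the per-route error-to-`Response` conversion removed above, the handlers log and re-raise, which implies exceptions are translated into HTTP responses in one central place (middleware or exception handlers) rather than in every route. That mechanism is not part of this hunk; purely as a generic illustration of the pattern, a FastAPI exception handler mapping forge's `NotFoundError` to a 404 could look like this:

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from forge.utils.exceptions import NotFoundError

app = FastAPI()


@app.exception_handler(NotFoundError)
async def not_found_handler(request: Request, exc: NotFoundError) -> JSONResponse:
    # Any NotFoundError raised inside a route (or in the Agent it calls)
    # ends up here instead of being handled inline per endpoint.
    return JSONResponse(status_code=404, content={"error": str(exc)})
```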
View File

@ -1 +1,3 @@
from .db import AgentDB
__all__ = ["AgentDB"]

View File

@ -4,23 +4,22 @@ It uses SQLite as the database and file store backend.
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
"""
import datetime
import logging
import math
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
ForeignKey,
String,
create_engine,
)
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
joinedload,
mapped_column,
relationship,
sessionmaker,
)
from forge.utils.exceptions import NotFoundError
@ -32,18 +31,20 @@ logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
pass
type_annotation_map = {
dict[str, Any]: JSON,
}
class TaskModel(Base):
__tablename__ = "tasks"
task_id = Column(String, primary_key=True, index=True)
input = Column(String)
additional_input = Column(JSON)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
input: Mapped[str]
additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
artifacts = relationship("ArtifactModel", back_populates="task")
@ -52,35 +53,35 @@ class TaskModel(Base):
class StepModel(Base):
__tablename__ = "steps"
step_id = Column(String, primary_key=True, index=True)
task_id = Column(String, ForeignKey("tasks.task_id"))
name = Column(String)
input = Column(String)
status = Column(String)
output = Column(String)
is_last = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
step_id: Mapped[str] = mapped_column(primary_key=True, index=True)
task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
name: Mapped[str]
input: Mapped[str]
status: Mapped[str]
output: Mapped[Optional[str]]
is_last: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
additional_input = Column(JSON)
additional_output = Column(JSON)
additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
additional_output: Mapped[Optional[dict[str, Any]]]
artifacts = relationship("ArtifactModel", back_populates="step")
class ArtifactModel(Base):
__tablename__ = "artifacts"
artifact_id = Column(String, primary_key=True, index=True)
task_id = Column(String, ForeignKey("tasks.task_id"))
step_id = Column(String, ForeignKey("steps.step_id"))
agent_created = Column(Boolean, default=False)
file_name = Column(String)
relative_path = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
artifact_id: Mapped[str] = mapped_column(primary_key=True, index=True)
task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
step_id: Mapped[Optional[str]] = mapped_column(ForeignKey("steps.step_id"))
agent_created: Mapped[bool] = mapped_column(default=False)
file_name: Mapped[str]
relative_path: Mapped[str]
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
default=datetime.utcnow, onupdate=datetime.utcnow
)
step = relationship("StepModel", back_populates="artifacts")
@ -150,6 +151,10 @@ class AgentDB:
Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine)
def close(self) -> None:
self.Session.close_all()
self.engine.dispose()
async def create_task(
self, input: Optional[str], additional_input: Optional[dict] = {}
) -> Task:
@ -172,8 +177,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating task: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while creating task: {e}")
raise
@ -207,8 +210,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while creating step: {e}")
raise
@ -237,7 +238,7 @@ class AgentDB:
session.close()
if self.debug_enabled:
logger.debug(
f"Artifact already exists with relative_path: {relative_path}"
f"Artifact {file_name} already exists at {relative_path}/"
)
return convert_to_artifact(existing_artifact)
@ -254,14 +255,12 @@ class AgentDB:
session.refresh(new_artifact)
if self.debug_enabled:
logger.debug(
f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
f"Created new artifact with ID: {new_artifact.artifact_id}"
)
return convert_to_artifact(new_artifact)
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while creating step: {e}")
raise
@ -285,8 +284,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting task: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while getting task: {e}")
raise
@ -312,8 +309,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while getting step: {e}")
raise
@ -337,8 +332,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting artifact: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while getting artifact: {e}")
raise
@ -375,14 +368,13 @@ class AgentDB:
return await self.get_step(task_id, step_id)
else:
logger.error(
f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
"Can't update non-existent Step with "
f"task_id: {task_id} and step_id: {step_id}"
)
raise NotFoundError("Step not found")
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while getting step: {e}")
raise
@ -441,8 +433,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing tasks: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while listing tasks: {e}")
raise
@ -475,8 +465,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing steps: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while listing steps: {e}")
raise
@ -509,8 +497,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing artifacts: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while listing artifacts: {e}")
raise

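The model changes above migrate from legacy `Column(...)` attributes to SQLAlchemy 2.0's typed `Mapped[...]` / `mapped_column()` declarations, with `type_annotation_map` telling the ORM to store `dict[str, Any]` columns as JSON and `Optional[...]` marking nullable columns. A compact stand-alone example of the same style (the table and fields below are illustrative):

```python
from datetime import datetime
from typing import Any, Optional

from sqlalchemy import JSON, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    # Map plain Python annotations onto column types.
    type_annotation_map = {dict[str, Any]: JSON}


class Note(Base):
    __tablename__ = "notes"

    note_id: Mapped[str] = mapped_column(primary_key=True, index=True)
    body: Mapped[str]                                   # NOT NULL string column
    summary: Mapped[Optional[str]]                      # nullable string column
    extra: Mapped[dict[str, Any]] = mapped_column(default=dict)
    created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
```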
View File

@ -22,14 +22,27 @@ from forge.agent_protocol.models import (
)
from forge.utils.exceptions import NotFoundError as DataNotFoundError
TEST_DB_FILENAME = "test_db.sqlite3"
TEST_DB_URL = f"sqlite:///{TEST_DB_FILENAME}"
@pytest.mark.asyncio
def test_table_creation():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
conn = sqlite3.connect("test_db.sqlite3")
cursor = conn.cursor()
@pytest.fixture
def agent_db():
db = AgentDB(TEST_DB_URL)
yield db
db.close()
os.remove(TEST_DB_FILENAME)
@pytest.fixture
def raw_db_connection(agent_db: AgentDB):
connection = sqlite3.connect(TEST_DB_FILENAME)
yield connection
connection.close()
def test_table_creation(raw_db_connection: sqlite3.Connection):
cursor = raw_db_connection.cursor()
# Test for tasks table existence
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
@ -45,8 +58,6 @@ def test_table_creation():
)
assert cursor.fetchone() is not None
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_task_schema():
@ -84,7 +95,10 @@ async def test_step_schema():
name="Write to file",
input="Write the words you receive to the file 'output.txt'.",
status=StepStatus.created,
output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
output=(
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
),
artifacts=[
Artifact(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -101,13 +115,13 @@ async def test_step_schema():
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
assert step.name == "Write to file"
assert step.status == StepStatus.created
assert (
step.output
== "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>"
assert step.output == (
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
)
assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last == False
assert step.is_last is False
@pytest.mark.asyncio
@ -118,6 +132,7 @@ async def test_convert_to_task():
created_at=now,
modified_at=now,
input="Write the words you receive to the file 'output.txt'.",
additional_input={},
artifacts=[
ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -147,6 +162,7 @@ async def test_convert_to_step():
name="Write to file",
status="created",
input="Write the words you receive to the file 'output.txt'.",
additional_input={},
artifacts=[
ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@ -166,7 +182,7 @@ async def test_convert_to_step():
assert step.status == StepStatus.created
assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last == False
assert step.is_last is False
@pytest.mark.asyncio
@ -183,91 +199,67 @@ async def test_convert_to_artifact():
artifact = convert_to_artifact(artifact_model)
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert artifact.relative_path == "file:///path/to/main.py"
assert artifact.agent_created == True
assert artifact.agent_created is True
@pytest.mark.asyncio
async def test_create_task():
# Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
# TODO: Fix this!
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_create_task(agent_db: AgentDB):
task = await agent_db.create_task("task_input")
assert task.input == "task_input"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_create_and_get_task():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_create_and_get_task(agent_db: AgentDB):
task = await agent_db.create_task("test_input")
fetched_task = await agent_db.get_task(task.task_id)
assert fetched_task.input == "test_input"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_get_task_not_found():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_get_task_not_found(agent_db: AgentDB):
with pytest.raises(DataNotFoundError):
await agent_db.get_task(9999)
os.remove(db_name.split("///")[1])
await agent_db.get_task("9999")
@pytest.mark.asyncio
async def test_create_and_get_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_create_and_get_step(agent_db: AgentDB):
task = await agent_db.create_task("task_input")
step_input = StepInput(type="python/code")
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await agent_db.create_step(task.task_id, request)
step = await agent_db.get_step(task.task_id, step.step_id)
assert step.input == "test_input debug"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_updating_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_updating_step(agent_db: AgentDB):
created_task = await agent_db.create_task("task_input")
step_input = StepInput(type="python/code")
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
created_step = await agent_db.create_step(created_task.task_id, request)
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
step = await agent_db.get_step(created_task.task_id, created_step.step_id)
assert step.status.value == "completed"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_get_step_not_found():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_get_step_not_found(agent_db: AgentDB):
with pytest.raises(DataNotFoundError):
await agent_db.get_step(9999, 9999)
os.remove(db_name.split("///")[1])
await agent_db.get_step("9999", "9999")
@pytest.mark.asyncio
async def test_get_artifact():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
async def test_get_artifact(agent_db: AgentDB):
# Given: A task and its corresponding artifact
task = await db.create_task("test_input debug")
step_input = StepInput(type="python/code")
task = await agent_db.create_task("test_input debug")
step_input = {"type": "python/code"}
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await db.create_step(task.task_id, requst)
step = await agent_db.create_step(task.task_id, requst)
# Create an artifact
artifact = await db.create_artifact(
artifact = await agent_db.create_artifact(
task_id=task.task_id,
file_name="test_get_artifact_sample_file.txt",
relative_path="file:///path/to/test_get_artifact_sample_file.txt",
@ -276,7 +268,7 @@ async def test_get_artifact():
)
# When: The artifact is fetched by its ID
fetched_artifact = await db.get_artifact(artifact.artifact_id)
fetched_artifact = await agent_db.get_artifact(artifact.artifact_id)
# Then: The fetched artifact matches the original
assert fetched_artifact.artifact_id == artifact.artifact_id
@ -285,47 +277,37 @@ async def test_get_artifact():
== "file:///path/to/test_get_artifact_sample_file.txt"
)
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_list_tasks():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
async def test_list_tasks(agent_db: AgentDB):
# Given: Multiple tasks in the database
task1 = await db.create_task("test_input_1")
task2 = await db.create_task("test_input_2")
task1 = await agent_db.create_task("test_input_1")
task2 = await agent_db.create_task("test_input_2")
# When: All tasks are fetched
fetched_tasks, pagination = await db.list_tasks()
fetched_tasks, pagination = await agent_db.list_tasks()
# Then: The fetched tasks list includes the created tasks
task_ids = [task.task_id for task in fetched_tasks]
assert task1.task_id in task_ids
assert task2.task_id in task_ids
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_list_steps():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
step_input = StepInput(type="python/code")
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
async def test_list_steps(agent_db: AgentDB):
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
# Given: A task and multiple steps for that task
task = await db.create_task("test_input")
step1 = await db.create_step(task.task_id, requst)
requst = StepRequestBody(input="step two", additional_input=step_input)
step2 = await db.create_step(task.task_id, requst)
task = await agent_db.create_task("test_input")
step1 = await agent_db.create_step(task.task_id, request)
request = StepRequestBody(input="step two")
step2 = await agent_db.create_step(task.task_id, request)
# When: All steps for the task are fetched
fetched_steps, pagination = await db.list_steps(task.task_id)
fetched_steps, pagination = await agent_db.list_steps(task.task_id)
# Then: The fetched steps list includes the created steps
step_ids = [step.step_id for step in fetched_steps]
assert step1.step_id in step_ids
assert step2.step_id in step_ids
os.remove(db_name.split("///")[1])

View File

@ -1,4 +1,4 @@
from .artifact import Artifact, ArtifactUpload
from .artifact import Artifact
from .pagination import Pagination
from .task import (
Step,
@ -10,3 +10,16 @@ from .task import (
TaskRequestBody,
TaskStepsListResponse,
)
__all__ = [
"Artifact",
"Pagination",
"Step",
"StepRequestBody",
"StepStatus",
"Task",
"TaskArtifactsListResponse",
"TaskListResponse",
"TaskRequestBody",
"TaskStepsListResponse",
]

View File

@ -3,15 +3,6 @@ from datetime import datetime
from pydantic import BaseModel, Field
class ArtifactUpload(BaseModel):
file: str = Field(..., description="File to upload.", format="binary")
relative_path: str = Field(
...,
description="Relative path of the artifact in the agent's workspace.",
example="python/code",
)
class Artifact(BaseModel):
created_at: datetime = Field(
...,

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import List, Optional
from typing import Any, List, Optional
from pydantic import BaseModel, Field
@ -17,7 +17,7 @@ class TaskRequestBody(BaseModel):
description="Input prompt for the task.",
example="Write the words you receive to the file 'output.txt'.",
)
additional_input: Optional[dict] = None
additional_input: dict[str, Any] = Field(default_factory=dict)
class Task(TaskRequestBody):
@ -38,8 +38,8 @@ class Task(TaskRequestBody):
description="The ID of the task.",
example="50da533e-3904-4401-8a07-c49adf88b5eb",
)
artifacts: Optional[List[Artifact]] = Field(
[],
artifacts: list[Artifact] = Field(
default_factory=list,
description="A list of artifacts that the task has produced.",
example=[
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
@ -50,14 +50,12 @@ class Task(TaskRequestBody):
class StepRequestBody(BaseModel):
name: Optional[str] = Field(
None, description="The name of the task step.", example="Write to file"
default=None, description="The name of the task step.", example="Write to file"
)
input: Optional[str] = Field(
None,
description="Input prompt for the step.",
example="Washington",
input: str = Field(
..., description="Input prompt for the step.", example="Washington"
)
additional_input: Optional[dict] = None
additional_input: dict[str, Any] = Field(default_factory=dict)
class StepStatus(Enum):
@ -90,19 +88,23 @@ class Step(StepRequestBody):
example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
)
name: Optional[str] = Field(
None, description="The name of the task step.", example="Write to file"
default=None, description="The name of the task step.", example="Write to file"
)
status: StepStatus = Field(
..., description="The status of the task step.", example="created"
)
output: Optional[str] = Field(
None,
default=None,
description="Output of the task step.",
example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')",
example=(
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')"
),
)
additional_output: Optional[dict] = None
artifacts: Optional[List[Artifact]] = Field(
[], description="A list of artifacts that the step has produced."
additional_output: Optional[dict[str, Any]] = None
artifacts: list[Artifact] = Field(
default_factory=list,
description="A list of artifacts that the step has produced.",
)
is_last: bool = Field(
..., description="Whether this is the last step in the task.", example=True
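
The model changes above replace `Optional` collection fields with non-optional ones that get per-instance defaults via `default_factory`. A minimal sketch of the pattern, assuming Pydantic v1 as used here; the `TaskSketch` model is hypothetical:

```python
from typing import Any

from pydantic import BaseModel, Field


class TaskSketch(BaseModel):
    # Hypothetical stand-in for the Task/StepRequestBody pattern above.
    input: str
    additional_input: dict[str, Any] = Field(default_factory=dict)
    artifact_ids: list[str] = Field(default_factory=list)


task = TaskSketch(input="Write the words you receive to the file 'output.txt'.")
assert task.artifact_ids == []          # an empty list, never None
task.additional_input["debug"] = True   # safe to mutate without a None guard
```

Compared with `Optional[...] = None`, callers can iterate and index these fields directly, and type checkers no longer force a None-narrowing step.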

View File

@ -1,3 +1,5 @@
from .command import Command, CommandOutput, CommandParameter
from .command import Command
from .decorator import command
from .parameter import CommandParameter
__all__ = ["Command", "CommandParameter", "command"]

View File

@ -1,14 +1,16 @@
from __future__ import annotations
import inspect
from typing import Any, Callable, Generic, ParamSpec, TypeVar
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
from forge.agent.protocols import CommandProvider
from .parameter import CommandParameter
CommandOutput = Any
P = ParamSpec("P")
CO = TypeVar("CO", bound=CommandOutput)
CO = TypeVar("CO") # command output
_CP = TypeVar("_CP", bound=CommandProvider)
class Command(Generic[P, CO]):
@ -24,7 +26,7 @@ class Command(Generic[P, CO]):
self,
names: list[str],
description: str,
method: Callable[P, CO],
method: Callable[Concatenate[_CP, P], CO],
parameters: list[CommandParameter],
):
# Check if all parameters are provided
@ -34,7 +36,9 @@ class Command(Generic[P, CO]):
)
self.names = names
self.description = description
self.method = method
# Method technically has a `self` parameter, but we can ignore that
# since Python passes it internally.
self.method = cast(Callable[P, CO], method)
self.parameters = parameters
@property
@ -62,7 +66,8 @@ class Command(Generic[P, CO]):
def __str__(self) -> str:
params = [
f"{param.name}: "
+ ("%s" if param.spec.required else "Optional[%s]") % param.spec.type.value
+ ("%s" if param.spec.required else "Optional[%s]")
% (param.spec.type.value if param.spec.type else "Any")
for param in self.parameters
]
return (
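
The `Concatenate[_CP, P]` annotation above types an unbound provider method: the first parameter is the provider instance, the remaining parameters are captured by the `ParamSpec`. A self-contained sketch of the same idea (the `FileProvider` class and `call_on` helper are hypothetical):

```python
from __future__ import annotations

from typing import Callable, Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


class FileProvider:
    def read_file(self, path: str) -> str:
        return f"<contents of {path}>"


def call_on(
    provider: FileProvider,
    method: Callable[Concatenate[FileProvider, P], R],
    *args: P.args,
    **kwargs: P.kwargs,
) -> R:
    # Concatenate tells the type checker that `method` expects the provider
    # as its first positional argument, followed by whatever P captures.
    return method(provider, *args, **kwargs)


print(call_on(FileProvider(), FileProvider.read_file, "notes.txt"))
```

Once a method is accessed on an instance, Python binds `self` automatically, which is why the constructor above can `cast` the stored callable down to `Callable[P, CO]`.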

View File

@ -1,2 +1,4 @@
from .action_history import ActionHistoryComponent
from .model import Episode, EpisodicActionHistory
__all__ = ["ActionHistoryComponent", "Episode", "EpisodicActionHistory"]

View File

@ -1,27 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional
from typing import TYPE_CHECKING, Callable, Iterator, Optional
from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
from forge.llm.prompting.utils import indent
from forge.llm.providers import ChatMessage, ChatModelProvider
from forge.llm.providers import ChatMessage, MultiProvider
if TYPE_CHECKING:
from forge.config.config import Config
from .model import AP, ActionResult, Episode, EpisodicActionHistory
from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory
class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[AP]):
class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
"""Keeps track of the event history and provides a summary of the steps."""
def __init__(
self,
event_history: EpisodicActionHistory[AP],
event_history: EpisodicActionHistory[AnyProposal],
max_tokens: int,
count_tokens: Callable[[str], int],
legacy_config: Config,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
) -> None:
self.event_history = event_history
self.max_tokens = max_tokens
@ -37,7 +37,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
):
yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")
def after_parse(self, result: AP) -> None:
def after_parse(self, result: AnyProposal) -> None:
self.event_history.register_action(result)
async def after_execute(self, result: ActionResult) -> None:
@ -48,7 +48,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
def _compile_progress(
self,
episode_history: list[Episode],
episode_history: list[Episode[AnyProposal]],
max_tokens: Optional[int] = None,
count_tokens: Optional[Callable[[str], int]] = None,
) -> str:
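
`ActionHistoryComponent` compiles a progress summary under a token budget using an injected `count_tokens` callable. A rough sketch of that pattern, independent of the component itself; the `compile_progress` helper and the whitespace tokenizer are hypothetical:

```python
from typing import Callable, Optional


def compile_progress(
    summaries: list[str],
    count_tokens: Callable[[str], int],
    max_tokens: Optional[int] = None,
) -> str:
    # Walk backwards so the most recent episodes survive when the budget runs out.
    lines: list[str] = []
    spent = 0
    for number, summary in zip(range(len(summaries), 0, -1), reversed(summaries)):
        line = f"{number}. {summary}"
        cost = count_tokens(line)
        if max_tokens is not None and spent + cost > max_tokens:
            break
        lines.append(line)
        spent += cost
    return "\n".join(reversed(lines))


print(compile_progress(
    ["Read instructions.txt", "Wrote output.txt"],
    count_tokens=lambda s: len(s.split()),  # crude stand-in for a real tokenizer
    max_tokens=16,
))
```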

View File

@ -1,25 +1,23 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, Generic, Iterator, TypeVar
from typing import TYPE_CHECKING, Generic
from pydantic import Field
from pydantic.generics import GenericModel
from forge.content_processing.text import summarize_text
from forge.llm.prompting.utils import format_numbered_list, indent
from forge.models.action import ActionProposal, ActionResult
from forge.models.action import ActionResult, AnyProposal
from forge.models.utils import ModelWithSummary
if TYPE_CHECKING:
from forge.config.config import Config
from forge.llm.providers import ChatModelProvider
AP = TypeVar("AP", bound=ActionProposal)
from forge.llm.providers import MultiProvider
class Episode(GenericModel, Generic[AP]):
action: AP
class Episode(GenericModel, Generic[AnyProposal]):
action: AnyProposal
result: ActionResult | None
summary: str | None = None
@ -54,32 +52,29 @@ class Episode(GenericModel, Generic[AP]):
return executed_action + action_result
class EpisodicActionHistory(GenericModel, Generic[AP]):
class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
"""Utility container for an action history"""
episodes: list[Episode[AP]] = Field(default_factory=list)
episodes: list[Episode[AnyProposal]] = Field(default_factory=list)
cursor: int = 0
_lock = asyncio.Lock()
@property
def current_episode(self) -> Episode[AP] | None:
def current_episode(self) -> Episode[AnyProposal] | None:
if self.cursor == len(self):
return None
return self[self.cursor]
def __getitem__(self, key: int) -> Episode[AP]:
def __getitem__(self, key: int) -> Episode[AnyProposal]:
return self.episodes[key]
def __iter__(self) -> Iterator[Episode[AP]]:
return iter(self.episodes)
def __len__(self) -> int:
return len(self.episodes)
def __bool__(self) -> bool:
return len(self.episodes) > 0
def register_action(self, action: AP) -> None:
def register_action(self, action: AnyProposal) -> None:
if not self.current_episode:
self.episodes.append(Episode(action=action, result=None))
assert self.current_episode
@ -113,7 +108,7 @@ class EpisodicActionHistory(GenericModel, Generic[AP]):
self.cursor = len(self.episodes)
async def handle_compression(
self, llm_provider: ChatModelProvider, app_config: Config
self, llm_provider: MultiProvider, app_config: Config
) -> None:
"""Compresses each episode in the action history using an LLM.

View File

@ -3,6 +3,11 @@ from .code_executor import (
DENYLIST_CONTROL,
CodeExecutionError,
CodeExecutorComponent,
is_docker_available,
we_are_running_in_a_docker_container,
)
__all__ = [
"ALLOWLIST_CONTROL",
"DENYLIST_CONTROL",
"CodeExecutionError",
"CodeExecutorComponent",
]

Some files were not shown because too many files have changed in this diff.