Add File Downloading Capabilities (#1680)

* Added 'download_file' command

* Added util and fixed spinner

* Fixed comma and added autogpt/auto_gpt_workspace to .gitignore

* Fix linter issues

* Fix more linter issues

* Fix Lint Issues

* Added 'download_file' command

* Added util and fixed spinner

* Fixed comma and added autogpt/auto_gpt_workspace to .gitignore

* Fix linter issues

* Fix more linter issues

* Conditionally add the 'download_file' prompt

* Update args.py

* Removed Duplicate Prompt

* Switched to using path_in_workspace function
This commit is contained in:
EH 2023-04-17 03:34:02 +01:00 committed by GitHub
parent 0409079983
commit 9589334a30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 107 additions and 3 deletions

1
.gitignore vendored
View File

@ -3,6 +3,7 @@ autogpt/keys.py
autogpt/*json
autogpt/node_modules/
autogpt/__pycache__/keys.cpython-310.pyc
autogpt/auto_gpt_workspace
package-lock.json
*.pyc
auto_gpt_workspace/*

View File

@ -17,6 +17,7 @@ from autogpt.commands.file_operations import (
read_file,
search_files,
write_to_file,
download_file
)
from autogpt.json_fixes.parsing import fix_and_parse_json
from autogpt.memory import get_memory
@ -164,6 +165,10 @@ def execute_command(command_name: str, arguments):
return delete_file(arguments["file"])
elif command_name == "search_files":
return search_files(arguments["directory"])
elif command_name == "download_file":
if not CFG.allow_downloads:
return "Error: You do not have user authorization to download files locally."
return download_file(arguments["url"], arguments["file"])
elif command_name == "browse_website":
return browse_website(arguments["url"], arguments["question"])
# TODO: Change these to take in a file rather than pasted code, if

View File

@ -1,7 +1,7 @@
"""This module contains the argument parsing logic for the script."""
import argparse
from colorama import Fore
from colorama import Fore, Back, Style
from autogpt import utils
from autogpt.config import Config
from autogpt.logs import logger
@ -63,6 +63,12 @@ def parse_arguments() -> None:
help="Specifies which ai_settings.yaml file to use, will also automatically"
" skip the re-prompt.",
)
parser.add_argument(
'--allow-downloads',
action='store_true',
dest='allow_downloads',
help='Dangerous: Allows Auto-GPT to download files natively.'
)
args = parser.parse_args()
if args.debug:
@ -133,5 +139,13 @@ def parse_arguments() -> None:
CFG.ai_settings_file = file
CFG.skip_reprompt = True
if args.allow_downloads:
logger.typewriter_log("Native Downloading:", Fore.GREEN, "ENABLED")
logger.typewriter_log("WARNING: ", Fore.YELLOW,
f"{Back.LIGHTYELLOW_EX}Auto-GPT will now be able to download and save files to your machine.{Back.RESET} " +
"It is recommended that you monitor any files it downloads carefully.")
logger.typewriter_log("WARNING: ", Fore.YELLOW, f"{Back.RED + Style.BRIGHT}ALWAYS REMEMBER TO NEVER OPEN FILES YOU AREN'T SURE OF!{Style.RESET_ALL}")
CFG.allow_downloads = True
if args.browser_name:
CFG.selenium_web_browser = args.browser_name

View File

@ -4,9 +4,16 @@ from __future__ import annotations
import os
import os.path
from pathlib import Path
from typing import Generator
from typing import Generator, List
import requests
from requests.adapters import HTTPAdapter
from requests.adapters import Retry
from colorama import Fore, Back
from autogpt.spinner import Spinner
from autogpt.utils import readable_file_size
from autogpt.workspace import path_in_workspace, WORKSPACE_PATH
LOG_FILE = "file_logger.txt"
LOG_FILE_PATH = WORKSPACE_PATH / LOG_FILE
@ -214,3 +221,43 @@ def search_files(directory: str) -> list[str]:
found_files.append(relative_path)
return found_files
def download_file(url, filename):
"""Downloads a file
Args:
url (str): URL of the file to download
filename (str): Filename to save the file as
"""
safe_filename = path_in_workspace(filename)
try:
message = f"{Fore.YELLOW}Downloading file from {Back.LIGHTBLUE_EX}{url}{Back.RESET}{Fore.RESET}"
with Spinner(message) as spinner:
session = requests.Session()
retry = Retry(total=3, backoff_factor=1, status_forcelist=[502, 503, 504])
adapter = HTTPAdapter(max_retries=retry)
session.mount('http://', adapter)
session.mount('https://', adapter)
total_size = 0
downloaded_size = 0
with session.get(url, allow_redirects=True, stream=True) as r:
r.raise_for_status()
total_size = int(r.headers.get('Content-Length', 0))
downloaded_size = 0
with open(safe_filename, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
downloaded_size += len(chunk)
# Update the progress message
progress = f"{readable_file_size(downloaded_size)} / {readable_file_size(total_size)}"
spinner.update_message(f"{message} {progress}")
return f'Successfully downloaded and locally stored file: "{filename}"! (Size: {readable_file_size(total_size)})'
except requests.HTTPError as e:
return f"Got an HTTP Error whilst trying to download file: {e}"
except Exception as e:
return "Error: " + str(e)

View File

@ -24,6 +24,7 @@ class Config(metaclass=Singleton):
self.continuous_limit = 0
self.speak_mode = False
self.skip_reprompt = False
self.allow_downloads = False
self.selenium_web_browser = os.getenv("USE_WEB_BROWSER", "chrome")
self.ai_settings_file = os.getenv("AI_SETTINGS_FILE", "ai_settings.yaml")

View File

@ -105,6 +105,16 @@ def get_prompt() -> str:
),
)
# Only add the download file command if the AI is allowed to execute it
if cfg.allow_downloads:
commands.append(
(
"Downloads a file from the internet, and stores it locally",
"download_file",
{"url": "<file_url>", "file": "<saved_filename>"}
),
)
# Add these command last.
commands.append(
("Do Nothing", "do_nothing", {}),

View File

@ -29,12 +29,14 @@ class Spinner:
time.sleep(self.delay)
sys.stdout.write(f"\r{' ' * (len(self.message) + 2)}\r")
def __enter__(self) -> None:
def __enter__(self):
"""Start the spinner"""
self.running = True
self.spinner_thread = threading.Thread(target=self.spin)
self.spinner_thread.start()
return self
def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
"""Stop the spinner
@ -48,3 +50,14 @@ class Spinner:
self.spinner_thread.join()
sys.stdout.write(f"\r{' ' * (len(self.message) + 2)}\r")
sys.stdout.flush()
def update_message(self, new_message, delay=0.1):
"""Update the spinner message
Args:
new_message (str): New message to display
delay: Delay in seconds before updating the message
"""
time.sleep(delay)
sys.stdout.write(f"\r{' ' * (len(self.message) + 2)}\r") # Clear the current message
sys.stdout.flush()
self.message = new_message

View File

@ -24,3 +24,16 @@ def validate_yaml_file(file: str):
)
return (True, f"Successfully validated {Fore.CYAN}`{file}`{Fore.RESET}!")
def readable_file_size(size, decimal_places=2):
"""Converts the given size in bytes to a readable format.
Args:
size: Size in bytes
decimal_places (int): Number of decimal places to display
"""
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if size < 1024.0:
break
size /= 1024.0
return f"{size:.{decimal_places}f} {unit}"