mirror of
https://github.com/Significant-Gravitas/Auto-GPT.git
synced 2025-01-08 11:57:32 +08:00
Add File Downloading Capabilities (#1680)
* Added 'download_file' command * Added util and fixed spinner * Fixed comma and added autogpt/auto_gpt_workspace to .gitignore * Fix linter issues * Fix more linter issues * Fix Lint Issues * Added 'download_file' command * Added util and fixed spinner * Fixed comma and added autogpt/auto_gpt_workspace to .gitignore * Fix linter issues * Fix more linter issues * Conditionally add the 'download_file' prompt * Update args.py * Removed Duplicate Prompt * Switched to using path_in_workspace function
This commit is contained in:
parent
0409079983
commit
9589334a30
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,6 +3,7 @@ autogpt/keys.py
|
||||
autogpt/*json
|
||||
autogpt/node_modules/
|
||||
autogpt/__pycache__/keys.cpython-310.pyc
|
||||
autogpt/auto_gpt_workspace
|
||||
package-lock.json
|
||||
*.pyc
|
||||
auto_gpt_workspace/*
|
||||
|
@ -17,6 +17,7 @@ from autogpt.commands.file_operations import (
|
||||
read_file,
|
||||
search_files,
|
||||
write_to_file,
|
||||
download_file
|
||||
)
|
||||
from autogpt.json_fixes.parsing import fix_and_parse_json
|
||||
from autogpt.memory import get_memory
|
||||
@ -164,6 +165,10 @@ def execute_command(command_name: str, arguments):
|
||||
return delete_file(arguments["file"])
|
||||
elif command_name == "search_files":
|
||||
return search_files(arguments["directory"])
|
||||
elif command_name == "download_file":
|
||||
if not CFG.allow_downloads:
|
||||
return "Error: You do not have user authorization to download files locally."
|
||||
return download_file(arguments["url"], arguments["file"])
|
||||
elif command_name == "browse_website":
|
||||
return browse_website(arguments["url"], arguments["question"])
|
||||
# TODO: Change these to take in a file rather than pasted code, if
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""This module contains the argument parsing logic for the script."""
|
||||
import argparse
|
||||
|
||||
from colorama import Fore
|
||||
from colorama import Fore, Back, Style
|
||||
from autogpt import utils
|
||||
from autogpt.config import Config
|
||||
from autogpt.logs import logger
|
||||
@ -63,6 +63,12 @@ def parse_arguments() -> None:
|
||||
help="Specifies which ai_settings.yaml file to use, will also automatically"
|
||||
" skip the re-prompt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--allow-downloads',
|
||||
action='store_true',
|
||||
dest='allow_downloads',
|
||||
help='Dangerous: Allows Auto-GPT to download files natively.'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.debug:
|
||||
@ -133,5 +139,13 @@ def parse_arguments() -> None:
|
||||
CFG.ai_settings_file = file
|
||||
CFG.skip_reprompt = True
|
||||
|
||||
if args.allow_downloads:
|
||||
logger.typewriter_log("Native Downloading:", Fore.GREEN, "ENABLED")
|
||||
logger.typewriter_log("WARNING: ", Fore.YELLOW,
|
||||
f"{Back.LIGHTYELLOW_EX}Auto-GPT will now be able to download and save files to your machine.{Back.RESET} " +
|
||||
"It is recommended that you monitor any files it downloads carefully.")
|
||||
logger.typewriter_log("WARNING: ", Fore.YELLOW, f"{Back.RED + Style.BRIGHT}ALWAYS REMEMBER TO NEVER OPEN FILES YOU AREN'T SURE OF!{Style.RESET_ALL}")
|
||||
CFG.allow_downloads = True
|
||||
|
||||
if args.browser_name:
|
||||
CFG.selenium_web_browser = args.browser_name
|
||||
|
@ -4,9 +4,16 @@ from __future__ import annotations
|
||||
import os
|
||||
import os.path
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from typing import Generator, List
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.adapters import Retry
|
||||
from colorama import Fore, Back
|
||||
from autogpt.spinner import Spinner
|
||||
from autogpt.utils import readable_file_size
|
||||
from autogpt.workspace import path_in_workspace, WORKSPACE_PATH
|
||||
|
||||
|
||||
LOG_FILE = "file_logger.txt"
|
||||
LOG_FILE_PATH = WORKSPACE_PATH / LOG_FILE
|
||||
|
||||
@ -214,3 +221,43 @@ def search_files(directory: str) -> list[str]:
|
||||
found_files.append(relative_path)
|
||||
|
||||
return found_files
|
||||
|
||||
|
||||
def download_file(url, filename):
|
||||
"""Downloads a file
|
||||
Args:
|
||||
url (str): URL of the file to download
|
||||
filename (str): Filename to save the file as
|
||||
"""
|
||||
safe_filename = path_in_workspace(filename)
|
||||
try:
|
||||
message = f"{Fore.YELLOW}Downloading file from {Back.LIGHTBLUE_EX}{url}{Back.RESET}{Fore.RESET}"
|
||||
with Spinner(message) as spinner:
|
||||
session = requests.Session()
|
||||
retry = Retry(total=3, backoff_factor=1, status_forcelist=[502, 503, 504])
|
||||
adapter = HTTPAdapter(max_retries=retry)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
|
||||
total_size = 0
|
||||
downloaded_size = 0
|
||||
|
||||
with session.get(url, allow_redirects=True, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get('Content-Length', 0))
|
||||
downloaded_size = 0
|
||||
|
||||
with open(safe_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
|
||||
# Update the progress message
|
||||
progress = f"{readable_file_size(downloaded_size)} / {readable_file_size(total_size)}"
|
||||
spinner.update_message(f"{message} {progress}")
|
||||
|
||||
return f'Successfully downloaded and locally stored file: "{filename}"! (Size: {readable_file_size(total_size)})'
|
||||
except requests.HTTPError as e:
|
||||
return f"Got an HTTP Error whilst trying to download file: {e}"
|
||||
except Exception as e:
|
||||
return "Error: " + str(e)
|
||||
|
@ -24,6 +24,7 @@ class Config(metaclass=Singleton):
|
||||
self.continuous_limit = 0
|
||||
self.speak_mode = False
|
||||
self.skip_reprompt = False
|
||||
self.allow_downloads = False
|
||||
|
||||
self.selenium_web_browser = os.getenv("USE_WEB_BROWSER", "chrome")
|
||||
self.ai_settings_file = os.getenv("AI_SETTINGS_FILE", "ai_settings.yaml")
|
||||
|
@ -105,6 +105,16 @@ def get_prompt() -> str:
|
||||
),
|
||||
)
|
||||
|
||||
# Only add the download file command if the AI is allowed to execute it
|
||||
if cfg.allow_downloads:
|
||||
commands.append(
|
||||
(
|
||||
"Downloads a file from the internet, and stores it locally",
|
||||
"download_file",
|
||||
{"url": "<file_url>", "file": "<saved_filename>"}
|
||||
),
|
||||
)
|
||||
|
||||
# Add these command last.
|
||||
commands.append(
|
||||
("Do Nothing", "do_nothing", {}),
|
||||
|
@ -29,12 +29,14 @@ class Spinner:
|
||||
time.sleep(self.delay)
|
||||
sys.stdout.write(f"\r{' ' * (len(self.message) + 2)}\r")
|
||||
|
||||
def __enter__(self) -> None:
|
||||
def __enter__(self):
|
||||
"""Start the spinner"""
|
||||
self.running = True
|
||||
self.spinner_thread = threading.Thread(target=self.spin)
|
||||
self.spinner_thread.start()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
|
||||
"""Stop the spinner
|
||||
|
||||
@ -48,3 +50,14 @@ class Spinner:
|
||||
self.spinner_thread.join()
|
||||
sys.stdout.write(f"\r{' ' * (len(self.message) + 2)}\r")
|
||||
sys.stdout.flush()
|
||||
|
||||
def update_message(self, new_message, delay=0.1):
|
||||
"""Update the spinner message
|
||||
Args:
|
||||
new_message (str): New message to display
|
||||
delay: Delay in seconds before updating the message
|
||||
"""
|
||||
time.sleep(delay)
|
||||
sys.stdout.write(f"\r{' ' * (len(self.message) + 2)}\r") # Clear the current message
|
||||
sys.stdout.flush()
|
||||
self.message = new_message
|
||||
|
@ -24,3 +24,16 @@ def validate_yaml_file(file: str):
|
||||
)
|
||||
|
||||
return (True, f"Successfully validated {Fore.CYAN}`{file}`{Fore.RESET}!")
|
||||
|
||||
|
||||
def readable_file_size(size, decimal_places=2):
|
||||
"""Converts the given size in bytes to a readable format.
|
||||
Args:
|
||||
size: Size in bytes
|
||||
decimal_places (int): Number of decimal places to display
|
||||
"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
|
||||
if size < 1024.0:
|
||||
break
|
||||
size /= 1024.0
|
||||
return f"{size:.{decimal_places}f} {unit}"
|
||||
|
Loading…
Reference in New Issue
Block a user