From 6e5e217f64e93ac041bd009c7a1c8960a0f4ebc8 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 29 Nov 2023 02:04:57 +0000 Subject: [PATCH] Support remote input --- scripts/inference/endpoint_generate.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py index df6cba8366..d94676a2e1 100644 --- a/scripts/inference/endpoint_generate.py +++ b/scripts/inference/endpoint_generate.py @@ -17,7 +17,8 @@ import pandas as pd import requests -from composer.utils import maybe_create_object_store_from_uri, parse_uri +from composer.utils import (get_file, maybe_create_object_store_from_uri, + parse_uri) from llmfoundry.utils import prompt_files as utils @@ -41,8 +42,8 @@ def parse_args() -> Namespace: '-p', '--prompts', nargs='+', - help='List of generation prompts or list of delimited files. Use syntax ' +\ - '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.' + help=f'List of strings, local datafiles (starting with {utils.PROMPTFILE_PREFIX}),' +\ + ' and/or remote object stores' ) parser.add_argument( '--prompt-delimiter', @@ -113,7 +114,22 @@ async def main(args: Namespace) -> None: if not api_key: log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}') - prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter) + new_prompts = [] + for prompt in args.prompts: + if prompt.startswith(utils.PROMPTFILE_PREFIX): + new_prompts.append(prompt) + continue + + input_object_store = maybe_create_object_store_from_uri(prompt) + if input_object_store is not None: + local_output_path = tempfile.TemporaryDirectory().name + get_file(prompt, str(local_output_path)) + log.info(f'Downloaded {prompt} to {local_output_path}') + prompt = f'{utils.PROMPTFILE_PREFIX}{local_output_path}' + + new_prompts.append(prompt) + + prompt_strings = utils.load_prompts(new_prompts, args.prompt_delimiter) cols = ['batch', 'prompt', 'output'] param_data = { @@ -194,6 +210,7 @@ async def generate(session: aiohttp.ClientSession, batch: int, _, _, output_folder_prefix = parse_uri(args.output_folder) remote_path = os.path.join(output_folder_prefix, file) output_object_store.upload_object(remote_path, local_path) + output_object_store.download_object log.info(f'Uploaded results to {args.output_folder}/{file}') else: output_dir, _ = os.path.split(args.output_folder)