Skip to content

Commit

Permalink
Support remote input
Browse files (browse the repository at this point in the history)
  • Loading branch information
aspfohl committed Nov 29, 2023
1 parent 17d1718 commit 6e5e217
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions scripts/inference/endpoint_generate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,7 +17,8 @@

import pandas as pd
import requests
from composer.utils import maybe_create_object_store_from_uri, parse_uri
from composer.utils import (get_file, maybe_create_object_store_from_uri,
parse_uri)

from llmfoundry.utils import prompt_files as utils

Expand All @@ -41,8 +42,8 @@ def parse_args() -> Namespace:
'-p',
'--prompts',
nargs='+',
help='List of generation prompts or list of delimited files. Use syntax ' +\
'"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.'
help=f'List of strings, local datafiles (starting with {utils.PROMPTFILE_PREFIX}),' +\
' and/or remote object stores'
)
parser.add_argument(
'--prompt-delimiter',
Expand Down Expand Up @@ -113,7 +114,22 @@ async def main(args: Namespace) -> None:
if not api_key:
log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}')

prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter)
new_prompts = []
for prompt in args.prompts:
if prompt.startswith(utils.PROMPTFILE_PREFIX):
new_prompts.append(prompt)
continue

input_object_store = maybe_create_object_store_from_uri(prompt)
if input_object_store is not None:
local_output_path = tempfile.TemporaryDirectory().name
get_file(prompt, str(local_output_path))
log.info(f'Downloaded {prompt} to {local_output_path}')
prompt = f'{utils.PROMPTFILE_PREFIX}{local_output_path}'

new_prompts.append(prompt)

prompt_strings = utils.load_prompts(new_prompts, args.prompt_delimiter)

cols = ['batch', 'prompt', 'output']
param_data = {
Expand Down Expand Up @@ -194,6 +210,7 @@ async def generate(session: aiohttp.ClientSession, batch: int,
_, _, output_folder_prefix = parse_uri(args.output_folder)
remote_path = os.path.join(output_folder_prefix, file)
output_object_store.upload_object(remote_path, local_path)
output_object_store.download_object
log.info(f'Uploaded results to {args.output_folder}/{file}')
else:
output_dir, _ = os.path.split(args.output_folder)
Expand Down

0 comments on commit 6e5e217

Please sign in to comment.