Add script for doing bulk generation against an endpoint (#765)
* Add script for doing bulk generation against an endpoint
* more logging
* warn
* fix
* format
* asdfads
* Add warning
* updates
* folder -> file
* remove blank line
* Support remote input
* prompts -> inputs
Showing 4 changed files with 309 additions and 21 deletions.
@@ -0,0 +1,58 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
from typing import List, Optional

PROMPTFILE_PREFIX = 'file::'


def load_prompts(prompts: List[str],
                 prompt_delimiter: Optional[str] = None) -> List[str]:
    """Loads a set of prompts, both free text and from file.

    Args:
        prompts (List[str]): List of free text prompts and prompt files
        prompt_delimiter (Optional str): Delimiter for text file
            If not provided, assumes the prompt file is a single prompt (non-delimited)

    Returns:
        List of prompt string(s)
    """
    prompt_strings = []
    for prompt in prompts:
        if prompt.startswith(PROMPTFILE_PREFIX):
            prompts = load_prompts_from_file(prompt, prompt_delimiter)
            prompt_strings.extend(prompts)
        else:
            prompt_strings.append(prompt)
    return prompt_strings


def load_prompts_from_file(prompt_path: str,
                           prompt_delimiter: Optional[str] = None) -> List[str]:
    """Loads a set of prompts from a text file.

    Args:
        prompt_path (str): Path for text file
        prompt_delimiter (Optional str): Delimiter for text file
            If not provided, assumes the prompt file is a single prompt (non-delimited)

    Returns:
        List of prompt string(s)
    """
    if not prompt_path.startswith(PROMPTFILE_PREFIX):
        raise ValueError(f'prompt_path must start with {PROMPTFILE_PREFIX}')

    _, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1)
    prompt_file_path = os.path.expanduser(prompt_file_path)
    if not os.path.isfile(prompt_file_path):
        raise FileNotFoundError(
            f'{prompt_file_path=} does not match any existing files.')

    with open(prompt_file_path, 'r') as f:
        prompt_string = f.read()

    if prompt_delimiter is None:
        return [prompt_string]
    return [i for i in prompt_string.split(prompt_delimiter) if i]
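
For illustration only, a minimal sketch of how these helpers are intended to be called; the prompt text, file path, and delimiter below are made-up example values, not part of the commit:

# Hypothetical usage (example values only). The free-text prompt is passed
# through unchanged; the 'file::' entry is read from disk and split on the
# delimiter, dropping empty chunks.
from llmfoundry.utils import prompt_files as utils

prompts = utils.load_prompts(
    ['Write a short poem about GPUs.', 'file::~/my_prompts.txt'],
    prompt_delimiter='\n')
# -> the free-text prompt, followed by one entry per non-empty line of the file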
@@ -0,0 +1,223 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Batch generate text completion results from an endpoint.

Warning: This script is experimental and could change or be removed at any time.
"""

import asyncio
import copy
import logging
import math
import os
import tempfile
import time
from argparse import ArgumentParser, Namespace

import pandas as pd
import requests
from composer.utils import (get_file, maybe_create_object_store_from_uri,
                            parse_uri)

from llmfoundry.utils import prompt_files as utils

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
log = logging.getLogger(__name__)

ENDPOINT_API_KEY_ENV: str = 'ENDPOINT_API_KEY'
ENDPOINT_URL_ENV: str = 'ENDPOINT_URL'

PROMPT_DELIMITER = '\n'


def parse_args() -> Namespace:
    """Parse commandline arguments."""
    parser = ArgumentParser(
        description='Call prompts against a text completions endpoint')

    #####
    # Path Parameters
    parser.add_argument(
        '-i',
        '--inputs',
        nargs='+',
        help=f'List of strings, local datafiles (starting with {utils.PROMPTFILE_PREFIX}),' +
        ' and/or remote object stores')
    parser.add_argument(
        '--prompt-delimiter',
        default='\n',
        help=
        'Prompt delimiter for txt files. By default, a file is a single prompt')

    parser.add_argument('-o',
                        '--output-folder',
                        required=True,
                        help='Remote folder to save the output')

    #####
    # Generation Parameters
    parser.add_argument(
        '--rate-limit',
        type=int,
        default=75,
        help='Max number of calls to make to the endpoint in a second')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=10,
        help='Max number of prompts to send to the endpoint in a single request')

    #####
    # Endpoint Parameters
    parser.add_argument(
        '-e',
        '--endpoint',
        type=str,
        help=
        f'OpenAI-compatible text completions endpoint to query. If not set, will read from {ENDPOINT_URL_ENV}'
    )

    parser.add_argument('--max-tokens', type=int, default=100)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top-k', type=int, default=50)
    parser.add_argument('--top-p', type=float, default=1.0)
    return parser.parse_args()


async def main(args: Namespace) -> None:
    # This is mildly experimental, so for now imports are not added as part of llm-foundry
    try:
        import aiohttp
    except ImportError as e:
        raise ImportError('Please install aiohttp') from e

    try:
        from ratelimit import limits, sleep_and_retry
    except ImportError as e:
        raise ImportError('Please install ratelimit') from e

    if args.batch_size > args.rate_limit:
        raise ValueError(
            f'Batch size is {args.batch_size} but rate limit is set to {args.rate_limit} / s'
        )

    url = args.endpoint if args.endpoint else os.environ.get(ENDPOINT_URL_ENV)
    if not url:
        raise ValueError(
            f'URL must be provided via --endpoint or {ENDPOINT_URL_ENV}')

    log.info(f'Using endpoint {url}')

    api_key = os.environ.get(ENDPOINT_API_KEY_ENV, '')
    if not api_key:
        log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}')

    new_inputs = []
    for prompt in args.inputs:
        if prompt.startswith(utils.PROMPTFILE_PREFIX):
            new_inputs.append(prompt)
            continue

        input_object_store = maybe_create_object_store_from_uri(prompt)
        if input_object_store is not None:
            local_output_path = tempfile.TemporaryDirectory().name
            get_file(prompt, str(local_output_path))
            log.info(f'Downloaded {prompt} to {local_output_path}')
            prompt = f'{utils.PROMPTFILE_PREFIX}{local_output_path}'

        new_inputs.append(prompt)

    prompt_strings = utils.load_prompts(new_inputs, args.prompt_delimiter)

    cols = ['batch', 'prompt', 'output']
    param_data = {
        'max_tokens': args.max_tokens,
        'temperature': args.temperature,
        'top_k': args.top_k,
        'top_p': args.top_p,
    }

    total_batches = math.ceil(len(prompt_strings) / args.batch_size)
    log.info(
        f'Generating {len(prompt_strings)} prompts in {total_batches} batches')

    @sleep_and_retry
    @limits(calls=total_batches, period=1)  # type: ignore
    async def generate(session: aiohttp.ClientSession, batch: int,
                       prompts: list):
        data = copy.copy(param_data)
        data['prompt'] = prompts
        headers = {'Authorization': api_key, 'Content-Type': 'application/json'}
        req_start = time.time()
        async with session.post(url, headers=headers, json=data) as resp:
            if resp.ok:
                try:
                    response = await resp.json()
                except requests.JSONDecodeError:
                    raise Exception(
                        f'Bad response: {resp.status} {resp.reason}')
            else:
                raise Exception(f'Bad response: {resp.status} {resp.reason}')

        req_end = time.time()
        n_compl = response['usage']['completion_tokens']
        n_prompt = response['usage']['prompt_tokens']
        req_latency = (req_end - req_start)
        log.info(f'Completed batch {batch}: {n_compl:,} completion' +
                 f' tokens using {n_prompt:,} prompt tokens in {req_latency}s')

        res = pd.DataFrame(columns=cols)

        for r in response['choices']:
            index = r['index']
            res.loc[len(res)] = [batch, prompts[index], r['text']]
        return res

    res = pd.DataFrame(columns=cols)
    batch = 0

    gen_start = time.time()
    async with aiohttp.ClientSession() as session:
        tasks = []

        for i in range(total_batches):
            prompts = prompt_strings[i * args.batch_size:min(
                (i + 1) * args.batch_size, len(prompt_strings))]

            tasks.append(generate(session, batch, prompts))
            batch += 1

        results = await asyncio.gather(*tasks)
        res = pd.concat(results)

    res.reset_index(drop=True, inplace=True)

    gen_end = time.time()
    gen_latency = (gen_end - gen_start)
    log.info(f'Generated {len(res)} prompts in {gen_latency}s, example data:')
    log.info(res.head())

    with tempfile.TemporaryDirectory() as tmp_dir:
        file = 'output.csv'
        local_path = os.path.join(tmp_dir, file)
        res.to_csv(local_path, index=False)

        output_object_store = maybe_create_object_store_from_uri(
            args.output_folder)
        if output_object_store is not None:
            _, _, output_folder_prefix = parse_uri(args.output_folder)
            remote_path = os.path.join(output_folder_prefix, file)
            output_object_store.upload_object(remote_path, local_path)
            log.info(f'Uploaded results to {args.output_folder}/{file}')
        else:
            output_dir, _ = os.path.split(args.output_folder)
            os.makedirs(output_dir, exist_ok=True)
            os.rename(local_path, args.output_folder)
            log.info(f'Saved results to {args.output_folder}')


if __name__ == '__main__':
    asyncio.run(main(parse_args()))
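
As a usage illustration (not from the commit), the script would typically be driven like this; the script filename, endpoint URL, and output bucket are placeholders, not values taken from the diff:

# Hypothetical invocation; endpoint_generate.py, the URL, and the S3 path are
# assumed names for the example.
export ENDPOINT_URL=https://my-endpoint.example.com/v1/completions
export ENDPOINT_API_KEY=my-secret-key
python endpoint_generate.py \
    --inputs "Write a limerick about LLMs" file::~/prompts.txt \
    --output-folder s3://my-bucket/generations \
    --batch-size 10 \
    --rate-limit 75 \
    --max-tokens 128 \
    --temperature 0.7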
@@ -0,0 +1,18 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

from llmfoundry.utils import prompt_files as utils


def test_load_prompt_strings(tmp_path: Path):
    assert utils.load_prompts(['hello', 'goodbye']) == ['hello', 'goodbye']

    with open(tmp_path / 'prompts.txt', 'w') as f:
        f.write('hello goodbye')

    temp = utils.PROMPTFILE_PREFIX + str(tmp_path / 'prompts.txt')
    assert utils.load_prompts(
        [temp, temp, 'why'],
        ' ') == ['hello', 'goodbye', 'hello', 'goodbye', 'why']
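
Not part of the commit, but a sketch of the error paths one might also cover, based on the ValueError and FileNotFoundError checks in load_prompts_from_file above (assumes pytest plus the same Path and utils imports as the test file):

# Illustrative only: exercises the error branches of load_prompts_from_file;
# these assertions are not in the diff.
import pytest


def test_load_prompts_from_file_errors(tmp_path: Path):
    # A path without the 'file::' prefix is rejected.
    with pytest.raises(ValueError):
        utils.load_prompts_from_file(str(tmp_path / 'prompts.txt'))

    # A prefixed path that does not exist raises FileNotFoundError.
    with pytest.raises(FileNotFoundError):
        utils.load_prompts_from_file(utils.PROMPTFILE_PREFIX +
                                     str(tmp_path / 'missing.txt'))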