From 2b482d23ee7d499e6b478881a85ece36c5ffa66b Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 29 Nov 2023 00:30:22 +0000 Subject: [PATCH] updates --- llmfoundry/utils/prompt_files.py | 58 ++++++++++++++ scripts/inference/endpoint_generate.py | 107 ++++++++++--------------- scripts/inference/hf_generate.py | 31 +++---- tests/test_prompt_files.py | 14 ++++ 4 files changed, 124 insertions(+), 86 deletions(-) create mode 100644 llmfoundry/utils/prompt_files.py create mode 100644 tests/test_prompt_files.py diff --git a/llmfoundry/utils/prompt_files.py b/llmfoundry/utils/prompt_files.py new file mode 100644 index 0000000000..7ce747b4c0 --- /dev/null +++ b/llmfoundry/utils/prompt_files.py @@ -0,0 +1,58 @@ +import os +from typing import List, Optional + +PROMPTFILE_PREFIX = "file::" + + +def load_prompts(prompts: List[str], + prompt_delimiter: Optional[str] = None) -> List[str]: + """ + Loads a set of prompts, both free text and from file + + Args: + prompts (List[str]): List of free text prompts and prompt files + prompt_delimiter (Optional str): Delimiter for text file + If not provided, assumes the prompt file is a single prompt (non-delimited) + + Returns: + List of prompt string(s) + """ + prompt_strings = [] + for prompt in prompts: + if prompt.startswith(PROMPTFILE_PREFIX): + prompts = load_prompts_from_file(prompt, prompt_delimiter) + prompt_strings.extend(prompts) + else: + prompt_strings.append(prompt) + return prompt_strings + + +def load_prompts_from_file(prompt_path: str, + prompt_delimiter: Optional[str] = None) -> List[str]: + """ + Load a set of prompts from a text fie + + Args: + prompt_path (str): Path for text file + prompt_delimiter (Optional str): Delimiter for text file + If not provided, assumes the prompt file is a single prompt (non-delimited) + + Returns: + List of prompt string(s) + """ + + if not prompt_path.startswith(PROMPTFILE_PREFIX): + raise ValueError(f'prompt_path_str must start with {PROMPTFILE_PREFIX}') + + _, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1) + prompt_file_path = os.path.expanduser(prompt_file_path) + if not os.path.isfile(prompt_file_path): + raise FileNotFoundError( + f'{prompt_file_path=} does not match any existing files.') + + with open(prompt_file_path, 'r') as f: + prompt_string = f.read() + + if prompt_delimiter is None: + return [prompt_string] + return [i for i in prompt_string.split(prompt_delimiter) if i] diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py index 77c9e691c9..56d04f536c 100644 --- a/scripts/inference/endpoint_generate.py +++ b/scripts/inference/endpoint_generate.py @@ -11,14 +11,15 @@ import logging import math import os +import tempfile import time from argparse import ArgumentParser, Namespace -from typing import List, cast import pandas as pd import requests -from composer.utils import (ObjectStore, maybe_create_object_store_from_uri, - parse_uri) +from composer.utils import maybe_create_object_store_from_uri, parse_uri + +from llmfoundry.utils import prompt_files as utils logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') log = logging.getLogger(__name__) @@ -40,20 +41,18 @@ def parse_args() -> Namespace: '-p', '--prompts', nargs='+', - help='Generation prompts. Use syntax "file::/path/to/prompt.txt" to load a ' +\ - 'prompt contained in a txt file.' + help='List of generation prompts or list of delimited files. Use syntax ' +\ + '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.' ) - - now = time.strftime('%Y%m%d-%H%M%S') - default_local_folder = f'/tmp/output/{now}' - parser.add_argument('-l', - '--local-folder', - type=str, - default=default_local_folder, - help='Local folder to save the output') + parser.add_argument( + '--prompt-delimiter', + default='\n', + help= + 'Prompt delimiter for txt files. By default, a file is a single prompt') parser.add_argument('-o', '--output-folder', + required=True, help='Remote folder to save the output') ##### @@ -61,7 +60,7 @@ def parse_args() -> Namespace: parser.add_argument( '--rate-limit', type=int, - default=10, + default=75, help='Max number of calls to make to the endpoint in a second') parser.add_argument( '--batch-size', @@ -86,21 +85,6 @@ def parse_args() -> Namespace: return parser.parse_args() -def load_prompts_from_file(prompt_path_str: str) -> List[str]: - # Note: slightly different than hf_generate.py (uses delimiter to split strings) - - if not prompt_path_str.startswith('file::'): - raise ValueError('prompt_path_str must start with "file::".') - _, prompt_file_path = prompt_path_str.split('file::', maxsplit=1) - prompt_file_path = os.path.expanduser(prompt_file_path) - if not os.path.isfile(prompt_file_path): - raise FileNotFoundError( - f'{prompt_file_path=} does not match any existing files.') - with open(prompt_file_path, 'r') as f: - prompt_string = f.read() - return prompt_string.split(PROMPT_DELIMITER) - - async def main(args: Namespace) -> None: # This is mildly experimental, so for now imports are not added as part of llm-foundry try: @@ -115,7 +99,7 @@ async def main(args: Namespace) -> None: if args.batch_size > args.rate_limit: raise ValueError( - f'Batch size is {args.batch_size} but rate limit is set to { args.rate_limit} / s' + f'Batch size is {args.batch_size} but rate limit is set to {args.rate_limit} / s' ) url = args.endpoint if args.endpoint else os.environ.get(ENDPOINT_URL_ENV) @@ -129,12 +113,7 @@ async def main(args: Namespace) -> None: if not api_key: log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}') - # Load prompts - prompt_strings = [] - for prompt in args.prompts: - if prompt.startswith('file::'): - prompt = load_prompts_from_file(prompt) - prompt_strings.append(prompt) + prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter) cols = ['batch', 'prompt', 'output'] param_data = { @@ -144,14 +123,17 @@ async def main(args: Namespace) -> None: 'top_p': args.top_p, } + total_batches = math.ceil(len(prompt_strings) / args.batch_size) + log.info( + f'Generating {len(prompt_strings)} prompts in {total_batches} batches') + @sleep_and_retry - @limits(calls=args.rate_limit // args.batch_size, period=1) # type: ignore + @limits(calls=total_batches, period=1) # type: ignore async def generate(session: aiohttp.ClientSession, batch: int, prompts: list): data = copy.copy(param_data) data['prompt'] = prompts headers = {'Authorization': api_key, 'Content-Type': 'application/json'} - req_start = time.time() async with session.post(url, headers=headers, json=data) as resp: if resp.ok: @@ -159,12 +141,9 @@ async def generate(session: aiohttp.ClientSession, batch: int, response = await resp.json() except requests.JSONDecodeError: raise Exception( - f'Bad response: {resp.status_code} {resp.reason}' # type: ignore - ) + f'Bad response: {resp.status} {resp.reason}') else: - raise Exception( - f'Bad response: {resp.status_code} {resp.content.decode().strip()}' # type: ignore - ) + raise Exception(f'Bad response: {resp.status} {resp.reason}') req_end = time.time() n_compl = response['usage']['completion_tokens'] @@ -183,10 +162,7 @@ async def generate(session: aiohttp.ClientSession, batch: int, res = pd.DataFrame(columns=cols) batch = 0 - total_batches = math.ceil(len(prompt_strings) / args.batch_size) - log.info( - f'Generating {len(prompt_strings)} prompts in {total_batches} batches') - + gen_start = time.time() async with aiohttp.ClientSession() as session: tasks = [] @@ -201,28 +177,29 @@ async def generate(session: aiohttp.ClientSession, batch: int, res = pd.concat(results) res.reset_index(drop=True, inplace=True) - log.info(f'Generated {len(res)} prompts, example data:') + + gen_end = time.time() + gen_latency = (gen_end - gen_start) + log.info(f'Generated {len(res)} prompts in {gen_latency}s, example data:') log.info(res.head()) - # save res to local output folder - os.makedirs(args.local_folder, exist_ok=True) - local_path = os.path.join(args.local_folder, 'output.csv') - res.to_csv(os.path.join(args.local_folder, 'output.csv'), index=False) - log.info(f'Saved results in {local_path}') - - if args.output_folder: - # Upload the local output to the remote location - output_object_store = cast( - ObjectStore, maybe_create_object_store_from_uri(args.output_folder)) - _, _, output_folder_prefix = parse_uri(args.output_folder) - files_to_upload = os.listdir(args.local_folder) - - for file in files_to_upload: - assert not os.path.isdir(file) - local_path = os.path.join(args.local_folder, file) + with tempfile.TemporaryDirectory() as tmp_dir: + file = 'output.csv' + local_path = os.path.join(tmp_dir, file) + res.to_csv(local_path, index=False) + + output_object_store = maybe_create_object_store_from_uri( + args.output_folder) + if output_object_store is not None: + _, _, output_folder_prefix = parse_uri(args.output_folder) remote_path = os.path.join(output_folder_prefix, file) output_object_store.upload_object(remote_path, local_path) - log.info(f'Uploaded {local_path} to {args.output_folder}/{file}') + log.info(f'Uploaded results to {args.output_folder}/{file}') + else: + os.makedirs(args.output_folder, exist_ok=True) + permanent_local = os.path.join(args.output_folder, file) + os.rename(local_path, permanent_local) + log.info(f'Saved results to {permanent_local}') if __name__ == '__main__': diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index 45ddc6b63e..6ac645e5b7 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 import itertools -import os import random import time import warnings @@ -13,6 +12,8 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from llmfoundry.utils import prompt_files as utils + def get_dtype(dtype: str): if dtype == 'fp32': @@ -62,9 +63,14 @@ def parse_args() -> Namespace: 'My name is', 'This is an explanation of deep learning to a five year old. Deep learning is', ], - help='Generation prompts. Use syntax "file::/path/to/prompt.txt" to load a ' +\ - 'prompt contained in a txt file.' + help='List of generation prompts or list of delimited files. Use syntax ' +\ + '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.' ) + parser.add_argument( + '--prompt-delimiter', + default=None, + help= + 'Prompt delimiter for txt files. By default, a file is a single prompt') parser.add_argument('--max_seq_len', type=int, default=None) parser.add_argument('--max_new_tokens', type=int, default=100) parser.add_argument('--max_batch_size', type=int, default=None) @@ -125,19 +131,6 @@ def parse_args() -> Namespace: return parser.parse_args() -def load_prompt_string_from_file(prompt_path_str: str): - if not prompt_path_str.startswith('file::'): - raise ValueError('prompt_path_str must start with "file::".') - _, prompt_file_path = prompt_path_str.split('file::', maxsplit=1) - prompt_file_path = os.path.expanduser(prompt_file_path) - if not os.path.isfile(prompt_file_path): - raise FileNotFoundError( - f'{prompt_file_path=} does not match any existing files.') - with open(prompt_file_path, 'r') as f: - prompt_string = ''.join(f.readlines()) - return prompt_string - - def maybe_synchronize(): if torch.cuda.is_available(): torch.cuda.synchronize() @@ -163,11 +156,7 @@ def main(args: Namespace) -> None: print(f'Using {model_dtype=}') # Load prompts - prompt_strings = [] - for prompt in args.prompts: - if prompt.startswith('file::'): - prompt = load_prompt_string_from_file(prompt) - prompt_strings.append(prompt) + prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter) # Grab config first print(f'Loading HF Config...') diff --git a/tests/test_prompt_files.py b/tests/test_prompt_files.py new file mode 100644 index 0000000000..5f2fb582fc --- /dev/null +++ b/tests/test_prompt_files.py @@ -0,0 +1,14 @@ +from pathlib import Path + +from llmfoundry.utils import prompt_files as utils + +def test_load_prompt_strings(tmp_path: Path): + assert utils.load_prompts(['hello', 'goodbye']) == ['hello', 'goodbye'] + + with open(tmp_path / 'prompts.txt', 'w') as f: + f.write('hello goodbye') + + temp = utils.PROMPTFILE_PREFIX + str(tmp_path / 'prompts.txt') + assert utils.load_prompts( + [temp, temp, 'why'], + ' ') == ['hello', 'goodbye', 'hello', 'goodbye', 'why']