
Commit

updates
aspfohl committed Nov 29, 2023
1 parent 89dc394 commit 2b482d2
Showing 4 changed files with 124 additions and 86 deletions.
58 changes: 58 additions & 0 deletions llmfoundry/utils/prompt_files.py
@@ -0,0 +1,58 @@
import os
from typing import List, Optional

PROMPTFILE_PREFIX = "file::"


def load_prompts(prompts: List[str],
                 prompt_delimiter: Optional[str] = None) -> List[str]:
    """Loads a set of prompts, both free text and from file.

    Args:
        prompts (List[str]): List of free text prompts and prompt files
        prompt_delimiter (Optional str): Delimiter for text file
            If not provided, assumes the prompt file is a single prompt (non-delimited)

    Returns:
        List of prompt string(s)
    """
    prompt_strings = []
    for prompt in prompts:
        if prompt.startswith(PROMPTFILE_PREFIX):
            prompts = load_prompts_from_file(prompt, prompt_delimiter)
            prompt_strings.extend(prompts)
        else:
            prompt_strings.append(prompt)
    return prompt_strings


def load_prompts_from_file(prompt_path: str,
                           prompt_delimiter: Optional[str] = None) -> List[str]:
    """Load a set of prompts from a text file.

    Args:
        prompt_path (str): Path for text file
        prompt_delimiter (Optional str): Delimiter for text file
            If not provided, assumes the prompt file is a single prompt (non-delimited)

    Returns:
        List of prompt string(s)
    """

    if not prompt_path.startswith(PROMPTFILE_PREFIX):
        raise ValueError(f'prompt_path must start with {PROMPTFILE_PREFIX}')

    _, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1)
    prompt_file_path = os.path.expanduser(prompt_file_path)
    if not os.path.isfile(prompt_file_path):
        raise FileNotFoundError(
            f'{prompt_file_path=} does not match any existing files.')

    with open(prompt_file_path, 'r') as f:
        prompt_string = f.read()

    if prompt_delimiter is None:
        return [prompt_string]
    return [i for i in prompt_string.split(prompt_delimiter) if i]
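
For reference, a minimal usage sketch of the new helper (the path and delimiter below are hypothetical examples, not part of the commit):

    from llmfoundry.utils.prompt_files import load_prompts

    # Free-text prompts pass through unchanged; 'file::' entries are expanded.
    # '~/prompts.txt' is a made-up example file; with a '\n' delimiter each
    # non-empty line becomes its own prompt, and with no delimiter the whole
    # file is returned as a single prompt.
    prompts = load_prompts(
        ['Tell me a joke', 'file::~/prompts.txt'],
        prompt_delimiter='\n',
    )
    print(len(prompts))
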
107 changes: 42 additions & 65 deletions scripts/inference/endpoint_generate.py
@@ -11,14 +11,15 @@
import logging
import math
import os
import tempfile
import time
from argparse import ArgumentParser, Namespace
from typing import List, cast

import pandas as pd
import requests
from composer.utils import (ObjectStore, maybe_create_object_store_from_uri,
parse_uri)
from composer.utils import maybe_create_object_store_from_uri, parse_uri

from llmfoundry.utils import prompt_files as utils

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
log = logging.getLogger(__name__)
@@ -40,28 +41,26 @@ def parse_args() -> Namespace:
'-p',
'--prompts',
nargs='+',
help='Generation prompts. Use syntax "file::/path/to/prompt.txt" to load a ' +\
'prompt contained in a txt file.'
help='List of generation prompts or list of delimited files. Use syntax ' +\
'"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.'
)

now = time.strftime('%Y%m%d-%H%M%S')
default_local_folder = f'/tmp/output/{now}'
parser.add_argument('-l',
'--local-folder',
type=str,
default=default_local_folder,
help='Local folder to save the output')
parser.add_argument(
'--prompt-delimiter',
default='\n',
help=
'Prompt delimiter for txt files. By default, a file is a single prompt')

parser.add_argument('-o',
'--output-folder',
required=True,
help='Remote folder to save the output')

#####
# Generation Parameters
parser.add_argument(
'--rate-limit',
type=int,
default=10,
default=75,
help='Max number of calls to make to the endpoint in a second')
parser.add_argument(
'--batch-size',
@@ -86,21 +85,6 @@
return parser.parse_args()


def load_prompts_from_file(prompt_path_str: str) -> List[str]:
# Note: slightly different than hf_generate.py (uses delimiter to split strings)

if not prompt_path_str.startswith('file::'):
raise ValueError('prompt_path_str must start with "file::".')
_, prompt_file_path = prompt_path_str.split('file::', maxsplit=1)
prompt_file_path = os.path.expanduser(prompt_file_path)
if not os.path.isfile(prompt_file_path):
raise FileNotFoundError(
f'{prompt_file_path=} does not match any existing files.')
with open(prompt_file_path, 'r') as f:
prompt_string = f.read()
return prompt_string.split(PROMPT_DELIMITER)


async def main(args: Namespace) -> None:
# This is mildly experimental, so for now imports are not added as part of llm-foundry
try:
@@ -115,7 +99,7 @@ async def main(args: Namespace) -> None:

if args.batch_size > args.rate_limit:
raise ValueError(
f'Batch size is {args.batch_size} but rate limit is set to { args.rate_limit} / s'
f'Batch size is {args.batch_size} but rate limit is set to {args.rate_limit} / s'
)

url = args.endpoint if args.endpoint else os.environ.get(ENDPOINT_URL_ENV)
@@ -129,12 +113,7 @@
if not api_key:
log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}')

# Load prompts
prompt_strings = []
for prompt in args.prompts:
if prompt.startswith('file::'):
prompt = load_prompts_from_file(prompt)
prompt_strings.append(prompt)
prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter)

cols = ['batch', 'prompt', 'output']
param_data = {
@@ -144,27 +123,27 @@
'top_p': args.top_p,
}

total_batches = math.ceil(len(prompt_strings) / args.batch_size)
log.info(
f'Generating {len(prompt_strings)} prompts in {total_batches} batches')

@sleep_and_retry
@limits(calls=args.rate_limit // args.batch_size, period=1) # type: ignore
@limits(calls=total_batches, period=1) # type: ignore
async def generate(session: aiohttp.ClientSession, batch: int,
prompts: list):
data = copy.copy(param_data)
data['prompt'] = prompts
headers = {'Authorization': api_key, 'Content-Type': 'application/json'}

req_start = time.time()
async with session.post(url, headers=headers, json=data) as resp:
if resp.ok:
try:
response = await resp.json()
except requests.JSONDecodeError:
raise Exception(
f'Bad response: {resp.status_code} {resp.reason}' # type: ignore
)
f'Bad response: {resp.status} {resp.reason}')
else:
raise Exception(
f'Bad response: {resp.status_code} {resp.content.decode().strip()}' # type: ignore
)
raise Exception(f'Bad response: {resp.status} {resp.reason}')

req_end = time.time()
n_compl = response['usage']['completion_tokens']
@@ -183,10 +162,7 @@ async def generate(session: aiohttp.ClientSession, batch: int,
res = pd.DataFrame(columns=cols)
batch = 0

total_batches = math.ceil(len(prompt_strings) / args.batch_size)
log.info(
f'Generating {len(prompt_strings)} prompts in {total_batches} batches')

gen_start = time.time()
async with aiohttp.ClientSession() as session:
tasks = []

@@ -201,28 +177,29 @@ async def generate(session: aiohttp.ClientSession, batch: int,
res = pd.concat(results)

res.reset_index(drop=True, inplace=True)
log.info(f'Generated {len(res)} prompts, example data:')

gen_end = time.time()
gen_latency = (gen_end - gen_start)
log.info(f'Generated {len(res)} prompts in {gen_latency}s, example data:')
log.info(res.head())

# save res to local output folder
os.makedirs(args.local_folder, exist_ok=True)
local_path = os.path.join(args.local_folder, 'output.csv')
res.to_csv(os.path.join(args.local_folder, 'output.csv'), index=False)
log.info(f'Saved results in {local_path}')

if args.output_folder:
# Upload the local output to the remote location
output_object_store = cast(
ObjectStore, maybe_create_object_store_from_uri(args.output_folder))
_, _, output_folder_prefix = parse_uri(args.output_folder)
files_to_upload = os.listdir(args.local_folder)

for file in files_to_upload:
assert not os.path.isdir(file)
local_path = os.path.join(args.local_folder, file)
with tempfile.TemporaryDirectory() as tmp_dir:
file = 'output.csv'
local_path = os.path.join(tmp_dir, file)
res.to_csv(local_path, index=False)

output_object_store = maybe_create_object_store_from_uri(
args.output_folder)
if output_object_store is not None:
_, _, output_folder_prefix = parse_uri(args.output_folder)
remote_path = os.path.join(output_folder_prefix, file)
output_object_store.upload_object(remote_path, local_path)
log.info(f'Uploaded {local_path} to {args.output_folder}/{file}')
log.info(f'Uploaded results to {args.output_folder}/{file}')
else:
os.makedirs(args.output_folder, exist_ok=True)
permanent_local = os.path.join(args.output_folder, file)
os.rename(local_path, permanent_local)
log.info(f'Saved results to {permanent_local}')


if __name__ == '__main__':
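
Side note on the throttling above: the script appears to rely on the ratelimit package's sleep_and_retry and limits decorators. A minimal, synchronous sketch of that pattern (the call budget below is illustrative only, not the script's configuration):

    import time

    from ratelimit import limits, sleep_and_retry

    @sleep_and_retry            # sleep until the window resets instead of raising
    @limits(calls=2, period=1)  # allow at most 2 calls per 1-second window
    def ping(i: int) -> None:
        print(i, time.time())

    for i in range(6):
        ping(i)  # the loop pauses once the per-second budget is exhausted
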
31 changes: 10 additions & 21 deletions scripts/inference/hf_generate.py
@@ -1,7 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import itertools
import os
import random
import time
import warnings
@@ -13,6 +12,8 @@
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmfoundry.utils import prompt_files as utils


def get_dtype(dtype: str):
if dtype == 'fp32':
@@ -62,9 +63,14 @@ def parse_args() -> Namespace:
'My name is',
'This is an explanation of deep learning to a five year old. Deep learning is',
],
help='Generation prompts. Use syntax "file::/path/to/prompt.txt" to load a ' +\
'prompt contained in a txt file.'
help='List of generation prompts or list of delimited files. Use syntax ' +\
'"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.'
)
parser.add_argument(
'--prompt-delimiter',
default=None,
help=
'Prompt delimiter for txt files. By default, a file is a single prompt')
parser.add_argument('--max_seq_len', type=int, default=None)
parser.add_argument('--max_new_tokens', type=int, default=100)
parser.add_argument('--max_batch_size', type=int, default=None)
@@ -125,19 +131,6 @@ def parse_args() -> Namespace:
return parser.parse_args()


def load_prompt_string_from_file(prompt_path_str: str):
if not prompt_path_str.startswith('file::'):
raise ValueError('prompt_path_str must start with "file::".')
_, prompt_file_path = prompt_path_str.split('file::', maxsplit=1)
prompt_file_path = os.path.expanduser(prompt_file_path)
if not os.path.isfile(prompt_file_path):
raise FileNotFoundError(
f'{prompt_file_path=} does not match any existing files.')
with open(prompt_file_path, 'r') as f:
prompt_string = ''.join(f.readlines())
return prompt_string


def maybe_synchronize():
if torch.cuda.is_available():
torch.cuda.synchronize()
@@ -163,11 +156,7 @@ def main(args: Namespace) -> None:
print(f'Using {model_dtype=}')

# Load prompts
prompt_strings = []
for prompt in args.prompts:
if prompt.startswith('file::'):
prompt = load_prompt_string_from_file(prompt)
prompt_strings.append(prompt)
prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter)

# Grab config first
print(f'Loading HF Config...')
14 changes: 14 additions & 0 deletions tests/test_prompt_files.py
@@ -0,0 +1,14 @@
from pathlib import Path

from llmfoundry.utils import prompt_files as utils

def test_load_prompt_strings(tmp_path: Path):
    assert utils.load_prompts(['hello', 'goodbye']) == ['hello', 'goodbye']

    with open(tmp_path / 'prompts.txt', 'w') as f:
        f.write('hello goodbye')

    temp = utils.PROMPTFILE_PREFIX + str(tmp_path / 'prompts.txt')
    assert utils.load_prompts(
        [temp, temp, 'why'],
        ' ') == ['hello', 'goodbye', 'hello', 'goodbye', 'why']
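
A natural companion check, sketched here as a hypothetical addition (not part of this commit), covers the documented no-delimiter behavior where an entire file loads as a single prompt:

    def test_load_prompt_file_without_delimiter(tmp_path: Path):
        with open(tmp_path / 'single.txt', 'w') as f:
            f.write('hello goodbye')

        temp = utils.PROMPTFILE_PREFIX + str(tmp_path / 'single.txt')
        # No delimiter passed, so the whole file comes back as one prompt.
        assert utils.load_prompts([temp]) == ['hello goodbye']
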
