diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py index 79a66d38b0..4acf146cf6 100644 --- a/scripts/inference/endpoint_generate.py +++ b/scripts/inference/endpoint_generate.py @@ -14,13 +14,13 @@ from composer.utils import (ObjectStore, maybe_create_object_store_from_uri, parse_uri) -logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") +logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') log = logging.getLogger(__name__) ENDPOINT_API_KEY_ENV: str = 'ENDPOINT_API_KEY' ENDPOINT_URL_ENV: str = 'ENDPOINT_URL' -PROMPT_DELIMITER = "\n" +PROMPT_DELIMITER = '\n' def parse_args() -> Namespace: @@ -38,17 +38,17 @@ def parse_args() -> Namespace: 'prompt contained in a txt file.' ) - now = time.strftime("%Y%m%d-%H%M%S") - default_local_folder = f"/tmp/output/{now}" + now = time.strftime('%Y%m%d-%H%M%S') + default_local_folder = f'/tmp/output/{now}' parser.add_argument('-l', '--local-folder', type=str, default=default_local_folder, - help="Local folder to save the output") + help='Local folder to save the output') parser.add_argument('-o', '--output-folder', - help="Remote folder to save the output") + help='Remote folder to save the output') ##### # Generation Parameters @@ -56,12 +56,12 @@ def parse_args() -> Namespace: '--rate-limit', type=int, default=5, - help="Max number of calls to make to the endpoint in a second") + help='Max number of calls to make to the endpoint in a second') parser.add_argument( '--batch-size', type=int, default=5, - help="Max number of calls to make to the endpoint in a single request") + help='Max number of calls to make to the endpoint in a single request') ##### # Endpoint Parameters @@ -124,12 +124,12 @@ async def main(args: Namespace) -> None: prompt = load_prompt_string_from_file(prompt) prompt_strings.append(prompt) - cols = ["batch", "prompt", "output"] + cols = ['batch', 'prompt', 'output'] param_data = { - "max_tokens": args.max_tokens, - "temperature": args.temperature, - "top_k": args.top_k, - "top_p": args.top_p, + 'max_tokens': args.max_tokens, + 'temperature': args.temperature, + 'top_k': args.top_k, + 'top_p': args.top_p, } @sleep_and_retry @@ -141,7 +141,7 @@ async def generate(session: aiohttp.ClientSession, batch: int, api_key = os.environ.get(ENDPOINT_API_KEY_ENV, '') if not api_key: log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}') - headers = {"Authorization": api_key, "Content-Type": "application/json"} + headers = {'Authorization': api_key, 'Content-Type': 'application/json'} req_start = time.time() async with session.post(url, headers=headers, json=data) as resp: @@ -158,17 +158,17 @@ async def generate(session: aiohttp.ClientSession, batch: int, ) req_end = time.time() - n_compl = response["usage"]["completion_tokens"] - n_prompt = response["usage"]["prompt_tokens"] + n_compl = response['usage']['completion_tokens'] + n_prompt = response['usage']['prompt_tokens'] req_latency = (req_end - req_start) log.info(f'Completed batch {batch}: {n_compl:,} completion' + f' tokens using {n_prompt:,} prompt tokens in {req_latency}s') res = pd.DataFrame(columns=cols) - for r in response["choices"]: - index = r["index"] - res.loc[len(res)] = [batch, prompts[index], r["text"]] + for r in response['choices']: + index = r['index'] + res.loc[len(res)] = [batch, prompts[index], r['text']] return res res = pd.DataFrame(columns=cols) @@ -176,7 +176,7 @@ async def generate(session: aiohttp.ClientSession, batch: int, total_batches = math.ceil(len(prompt_strings) / args.batch_size) log.info( - f"Generating {len(prompt_strings)} prompts in {total_batches} batches") + f'Generating {len(prompt_strings)} prompts in {total_batches} batches') async with aiohttp.ClientSession() as session: tasks = [] @@ -192,13 +192,13 @@ async def generate(session: aiohttp.ClientSession, batch: int, res = pd.concat(results) res.reset_index(drop=True, inplace=True) - log.info(f"Generated {len(res)} prompts, example data:") + log.info(f'Generated {len(res)} prompts, example data:') log.info(res.head()) # save res to local output folder os.makedirs(args.local_folder, exist_ok=True) - local_path = os.path.join(args.local_folder, "output.csv") - res.to_csv(os.path.join(args.local_folder, "output.csv"), index=False) + local_path = os.path.join(args.local_folder, 'output.csv') + res.to_csv(os.path.join(args.local_folder, 'output.csv'), index=False) log.info(f'Saved results in {local_path}') if args.output_folder: