From f0badd5ac3946dccf9db0f1edcd1ae079cc273bf Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Tue, 28 Nov 2023 11:56:44 -0800 Subject: [PATCH] more logging --- scripts/inference/endpoint_generate.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py index 0749efdffb..352045f47c 100644 --- a/scripts/inference/endpoint_generate.py +++ b/scripts/inference/endpoint_generate.py @@ -113,6 +113,8 @@ async def main(args: Namespace) -> None: raise ValueError( f'URL must be provided via --endpoint or {ENDPOINT_URL_ENV}') + log.info(f'Using endpoint {url}') + # Load prompts prompt_strings = [] for prompt in args.prompts: @@ -134,10 +136,10 @@ async def generate(session: aiohttp.ClientSession, batch: int, prompts: list): data = copy.copy(param_data) data['prompt'] = prompts - headers = { - "Authorization": os.environ.get(ENDPOINT_API_KEY_ENV), - "Content-Type": "application/json" - } + api_key = os.environ.get(ENDPOINT_API_KEY_ENV, '') + if not api_key: + log.warn('API key not set in {ENDPOINT_API_KEY_ENV}') + headers = {"Authorization": api_key, "Content-Type": "application/json"} req_start = time.time() async with session.post(url, headers=headers, json=data) as resp: