From ce4464e81f9ebe8905644d0094109266890de95d Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Tue, 9 Jan 2024 18:33:26 +0000 Subject: [PATCH] formatting --- .../models/inference_api_wrapper/fmapi.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/llmfoundry/models/inference_api_wrapper/fmapi.py b/llmfoundry/models/inference_api_wrapper/fmapi.py index d1485de68c..97897589a8 100644 --- a/llmfoundry/models/inference_api_wrapper/fmapi.py +++ b/llmfoundry/models/inference_api_wrapper/fmapi.py @@ -1,15 +1,16 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import os import logging -from typing import Dict +import os import time +from typing import Dict + import requests from transformers import AutoTokenizer from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( - OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper) + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface) __all__ = [ 'FMAPICasualLMEvalWrapper', @@ -21,9 +22,8 @@ def block_until_ready(base_url: str): """Block until the endpoint is ready.""" - sleep_s = 5 - remaining_s = 5 * 50 # At max, wait 5 minutes + remaining_s = 5 * 50 # At max, wait 5 minutes ping_url = f'{base_url}/ping' @@ -32,7 +32,9 @@ def block_until_ready(base_url: str): requests.get(ping_url) break except requests.exceptions.ConnectionError: - log.debug(f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds') + log.debug( + f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds' + ) time.sleep(sleep_s) remaining_s -= sleep_s else: @@ -40,20 +42,23 @@ def block_until_ready(base_url: str): break if remaining_s <= 0: - raise TimeoutError(f'Endpoint {ping_url} never became ready, exiting') + raise TimeoutError( + f'Endpoint {ping_url} never became ready, exiting') -class FMAPIEvalInterface: + +class FMAPIEvalInterface(OpenAIEvalInterface): def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer): is_local = model_cfg.get('local', False) if is_local: base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT', - 'http://0.0.0.0:8080/v2') + 'http://0.0.0.0:8080/v2') elif 'base_url' in model_cfg: base_url = model_cfg['base_url'] else: - raise ValueError('Must specify base_url in model_cfg for FMAPIsEvalWrapper') - + raise ValueError( + 'Must specify base_url in model_cfg for FMAPIsEvalWrapper') + model_cfg['base_url'] = base_url if is_local: