diff --git a/llmfoundry/models/inference_api_wrapper/__init__.py b/llmfoundry/models/inference_api_wrapper/__init__.py index 496abf2aa6..9bb2ece2b2 100644 --- a/llmfoundry/models/inference_api_wrapper/__init__.py +++ b/llmfoundry/models/inference_api_wrapper/__init__.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.models.inference_api_wrapper.fmapi import ( + FMAPICasualLMEvalWrapper, FMAPIChatAPIEvalWrapper) from llmfoundry.models.inference_api_wrapper.interface import \ InferenceAPIEvalWrapper from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( @@ -10,4 +12,6 @@ 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', 'InferenceAPIEvalWrapper', + 'FMAPICasualLMEvalWrapper', + 'FMAPIChatAPIEvalWrapper', ] diff --git a/llmfoundry/models/inference_api_wrapper/fmapi.py b/llmfoundry/models/inference_api_wrapper/fmapi.py new file mode 100644 index 0000000000..867b3c272e --- /dev/null +++ b/llmfoundry/models/inference_api_wrapper/fmapi.py @@ -0,0 +1,72 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import time +from typing import Dict + +import requests +from transformers import AutoTokenizer + +from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface) + +__all__ = [ + 'FMAPICasualLMEvalWrapper', + 'FMAPIChatAPIEvalWrapper', +] + +log = logging.getLogger(__name__) + + +def block_until_ready(base_url: str): + """Block until the endpoint is ready.""" + sleep_s = 5 + timout_s = 5 * 60 # At max, wait 5 minutes + + ping_url = f'{base_url}/ping' + + waited_s = 0 + while True: + try: + requests.get(ping_url) + log.info(f'Endpoint {ping_url} is ready') + break + except requests.exceptions.ConnectionError: + log.debug( + f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds' + ) + time.sleep(sleep_s) + waited_s += sleep_s + + if waited_s >= timout_s: + raise TimeoutError( + f'Endpoint {ping_url} did not become read after {waited_s:,} seconds, exiting' + ) + + +class FMAPIEvalInterface(OpenAIEvalInterface): + + def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer): + is_local = model_cfg.pop('local', False) + if is_local: + base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT', + 'http://0.0.0.0:8080/v2') + model_cfg['base_url'] = base_url + block_until_ready(base_url) + + if 'base_url' not in model_cfg: + raise ValueError( + 'Must specify base_url or use local=True in model_cfg for FMAPIsEvalWrapper' + ) + + super().__init__(model_cfg, tokenizer) + + +class FMAPICasualLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper): + """Databricks Foundational Model API wrapper for causal LM models.""" + + +class FMAPIChatAPIEvalWrapper(FMAPIEvalInterface, OpenAIChatAPIEvalWrapper): + """Databricks Foundational Model API wrapper for chat models.""" diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py index 39de2ba59c..587dd179bd 100644 --- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py @@ -36,9 +36,6 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper): def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: super().__init__(model_cfg, tokenizer) - assert os.getenv( - 'OPENAI_API_KEY' - ) is not None, 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.' try: import openai except ImportError as e: @@ -46,8 +43,28 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: extra_deps_group='openai', conda_package='openai', conda_channel='conda-forge') from e - self.client = openai.OpenAI() - self.model_name = model_cfg['version'] + + api_key = os.environ.get('OPENAI_API_KEY') + base_url = model_cfg.get('base_url') + if base_url is None: + # Using OpenAI default, where the API key is required + if api_key is None: + raise ValueError( + 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.' + ) + + else: + # Using a custom base URL, where the API key may not be required + log.info( + f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}' + ) + api_key = 'placeholder' # This cannot be None + + self.client = openai.OpenAI(base_url=base_url, api_key=api_key) + if 'version' in model_cfg: + self.model_name = model_cfg['version'] + else: + self.model_name = model_cfg['name'] def generate_completion(self, prompt: str, num_tokens: int): raise NotImplementedError() diff --git a/llmfoundry/models/model_registry.py b/llmfoundry/models/model_registry.py index be09a69835..ff9942f5f6 100644 --- a/llmfoundry/models/model_registry.py +++ b/llmfoundry/models/model_registry.py @@ -3,7 +3,9 @@ from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM, ComposerHFT5) -from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper, +from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper, + FMAPIChatAPIEvalWrapper, + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper) from llmfoundry.models.mpt import ComposerMPTCausalLM @@ -13,5 +15,7 @@ 'hf_prefix_lm': ComposerHFPrefixLM, 'hf_t5': ComposerHFT5, 'openai_causal_lm': OpenAICausalLMEvalWrapper, - 'openai_chat': OpenAIChatAPIEvalWrapper + 'fmapi_causal_lm': FMAPICasualLMEvalWrapper, + 'openai_chat': OpenAIChatAPIEvalWrapper, + 'fmapi_chat': FMAPIChatAPIEvalWrapper, }