Commit

Merge branch 'main' into text2mds-error
irenedea authored Jan 23, 2024
2 parents 9953fe7 + f2614a4 commit 74e24ba
Showing 4 changed files with 104 additions and 7 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/models/inference_api_wrapper/__init__.py
@@ -1,6 +1,8 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from llmfoundry.models.inference_api_wrapper.fmapi import (
+    FMAPICasualLMEvalWrapper, FMAPIChatAPIEvalWrapper)
 from llmfoundry.models.inference_api_wrapper.interface import \
     InferenceAPIEvalWrapper
 from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
@@ -10,4 +12,6 @@
     'OpenAICausalLMEvalWrapper',
     'OpenAIChatAPIEvalWrapper',
     'InferenceAPIEvalWrapper',
+    'FMAPICasualLMEvalWrapper',
+    'FMAPIChatAPIEvalWrapper',
 ]
72 changes: 72 additions & 0 deletions llmfoundry/models/inference_api_wrapper/fmapi.py
@@ -0,0 +1,72 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import time
+from typing import Dict
+
+import requests
+from transformers import AutoTokenizer
+
+from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
+    OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface)
+
+__all__ = [
+    'FMAPICasualLMEvalWrapper',
+    'FMAPIChatAPIEvalWrapper',
+]
+
+log = logging.getLogger(__name__)
+
+
+def block_until_ready(base_url: str):
+    """Block until the endpoint is ready."""
+    sleep_s = 5
+    timeout_s = 5 * 60  # At max, wait 5 minutes
+
+    ping_url = f'{base_url}/ping'
+
+    waited_s = 0
+    while True:
+        try:
+            requests.get(ping_url)
+            log.info(f'Endpoint {ping_url} is ready')
+            break
+        except requests.exceptions.ConnectionError:
+            log.debug(
+                f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds'
+            )
+            time.sleep(sleep_s)
+            waited_s += sleep_s
+
+        if waited_s >= timeout_s:
+            raise TimeoutError(
+                f'Endpoint {ping_url} did not become ready after {waited_s:,} seconds, exiting'
+            )
+
+
+class FMAPIEvalInterface(OpenAIEvalInterface):
+
+    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
+        is_local = model_cfg.pop('local', False)
+        if is_local:
+            base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT',
+                                      'http://0.0.0.0:8080/v2')
+            model_cfg['base_url'] = base_url
+            block_until_ready(base_url)
+
+        if 'base_url' not in model_cfg:
+            raise ValueError(
+                'Must specify base_url or use local=True in model_cfg for FMAPIEvalWrapper'
+            )
+
+        super().__init__(model_cfg, tokenizer)
+
+
+class FMAPICasualLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper):
+    """Databricks Foundational Model API wrapper for causal LM models."""
+
+
+class FMAPIChatAPIEvalWrapper(FMAPIEvalInterface, OpenAIChatAPIEvalWrapper):
+    """Databricks Foundational Model API wrapper for chat models."""
27 changes: 22 additions & 5 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -36,18 +36,35 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper):
 
     def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
         super().__init__(model_cfg, tokenizer)
-        assert os.getenv(
-            'OPENAI_API_KEY'
-        ) is not None, 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.'
         try:
             import openai
         except ImportError as e:
             raise MissingConditionalImportError(
                 extra_deps_group='openai',
                 conda_package='openai',
                 conda_channel='conda-forge') from e
-        self.client = openai.OpenAI()
-        self.model_name = model_cfg['version']
+
+        api_key = os.environ.get('OPENAI_API_KEY')
+        base_url = model_cfg.get('base_url')
+        if base_url is None:
+            # Using OpenAI default, where the API key is required
+            if api_key is None:
+                raise ValueError(
+                    'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.'
+                )
+
+        else:
+            # Using a custom base URL, where the API key may not be required
+            log.info(
+                f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}'
+            )
+            api_key = 'placeholder'  # This cannot be None
+
+        self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
+        if 'version' in model_cfg:
+            self.model_name = model_cfg['version']
+        else:
+            self.model_name = model_cfg['name']
 
     def generate_completion(self, prompt: str, num_tokens: int):
         raise NotImplementedError()
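The net effect is two credential modes. A short sketch of how a caller might configure each, with illustrative model names:

    import os

    # Default OpenAI endpoint: OPENAI_API_KEY must be set or __init__ raises ValueError.
    os.environ['OPENAI_API_KEY'] = 'sk-...'    # placeholder, not a real key
    cfg_openai = {'version': 'gpt-3.5-turbo'}  # 'version' sets self.model_name

    # Custom endpoint: the key is optional; the literal string 'placeholder' is
    # passed to openai.OpenAI(), and 'name' is used when 'version' is absent.
    cfg_custom = {'base_url': 'http://0.0.0.0:8080/v2', 'name': 'my-model'}
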
8 changes: 6 additions & 2 deletions llmfoundry/models/model_registry.py
@@ -3,7 +3,9 @@
 
 from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
                                   ComposerHFT5)
-from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper,
+from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper,
+                                                     FMAPIChatAPIEvalWrapper,
+                                                     OpenAICausalLMEvalWrapper,
                                                      OpenAIChatAPIEvalWrapper)
 from llmfoundry.models.mpt import ComposerMPTCausalLM
 
@@ -13,5 +15,7 @@
     'hf_prefix_lm': ComposerHFPrefixLM,
     'hf_t5': ComposerHFT5,
     'openai_causal_lm': OpenAICausalLMEvalWrapper,
-    'openai_chat': OpenAIChatAPIEvalWrapper
+    'fmapi_causal_lm': FMAPICasualLMEvalWrapper,
+    'openai_chat': OpenAIChatAPIEvalWrapper,
+    'fmapi_chat': FMAPIChatAPIEvalWrapper,
 }
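
Eval configs can now select the new wrappers through the registry keys. A lookup sketch; the dict name COMPOSER_MODEL_REGISTRY is an assumption based on the unchanged portion of this file, which the hunk does not show:

    from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY

    chat_cls = COMPOSER_MODEL_REGISTRY['fmapi_chat']         # FMAPIChatAPIEvalWrapper
    causal_cls = COMPOSER_MODEL_REGISTRY['fmapi_causal_lm']  # FMAPICasualLMEvalWrapper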
