Merge branch 'main' into qkgn
vchiley authored Jan 23, 2024
2 parents ebd4c1e + f2614a4 commit 1f59606
Showing 10 changed files with 154 additions and 20 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/models/inference_api_wrapper/__init__.py
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.inference_api_wrapper.fmapi import (
FMAPICasualLMEvalWrapper, FMAPIChatAPIEvalWrapper)
from llmfoundry.models.inference_api_wrapper.interface import \
InferenceAPIEvalWrapper
from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
@@ -10,4 +12,6 @@
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
'InferenceAPIEvalWrapper',
'FMAPICasualLMEvalWrapper',
'FMAPIChatAPIEvalWrapper',
]
72 changes: 72 additions & 0 deletions llmfoundry/models/inference_api_wrapper/fmapi.py
@@ -0,0 +1,72 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import time
from typing import Dict

import requests
from transformers import AutoTokenizer

from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface)

__all__ = [
'FMAPICasualLMEvalWrapper',
'FMAPIChatAPIEvalWrapper',
]

log = logging.getLogger(__name__)


def block_until_ready(base_url: str):
"""Block until the endpoint is ready."""
sleep_s = 5
    timeout_s = 5 * 60  # At max, wait 5 minutes

ping_url = f'{base_url}/ping'

waited_s = 0
while True:
try:
requests.get(ping_url)
log.info(f'Endpoint {ping_url} is ready')
break
except requests.exceptions.ConnectionError:
log.debug(
f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds'
)
time.sleep(sleep_s)
waited_s += sleep_s

        if waited_s >= timeout_s:
            raise TimeoutError(
                f'Endpoint {ping_url} did not become ready after {waited_s:,} seconds, exiting'
)


class FMAPIEvalInterface(OpenAIEvalInterface):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
is_local = model_cfg.pop('local', False)
if is_local:
base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT',
'http://0.0.0.0:8080/v2')
model_cfg['base_url'] = base_url
block_until_ready(base_url)

if 'base_url' not in model_cfg:
raise ValueError(
                'Must specify base_url or use local=True in model_cfg for FMAPIEvalWrapper'
)

super().__init__(model_cfg, tokenizer)


class FMAPICasualLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper):
"""Databricks Foundational Model API wrapper for causal LM models."""


class FMAPIChatAPIEvalWrapper(FMAPIEvalInterface, OpenAIChatAPIEvalWrapper):
"""Databricks Foundational Model API wrapper for chat models."""
27 changes: 22 additions & 5 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -36,18 +36,35 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
assert os.getenv(
'OPENAI_API_KEY'
) is not None, 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.'
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
self.client = openai.OpenAI()
self.model_name = model_cfg['version']

api_key = os.environ.get('OPENAI_API_KEY')
base_url = model_cfg.get('base_url')
if base_url is None:
# Using OpenAI default, where the API key is required
if api_key is None:
raise ValueError(
                    'No OpenAI API Key found. Ensure it is saved as an environment variable called OPENAI_API_KEY.'
)

else:
# Using a custom base URL, where the API key may not be required
log.info(
f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}'
)
api_key = 'placeholder' # This cannot be None

self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
if 'version' in model_cfg:
self.model_name = model_cfg['version']
else:
self.model_name = model_cfg['name']

def generate_completion(self, prompt: str, num_tokens: int):
raise NotImplementedError()
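The key handling above reduces to a small rule: the default OpenAI endpoint requires OPENAI_API_KEY, while a custom base_url substitutes a placeholder because openai.OpenAI() rejects a None key. A standalone sketch of the same branch (make_client is a hypothetical helper, not part of this diff):

import os

import openai


def make_client(base_url=None):
    # Mirrors the key handling in OpenAIEvalInterface.__init__ above.
    api_key = os.environ.get('OPENAI_API_KEY')
    if base_url is None and api_key is None:
        raise ValueError('OPENAI_API_KEY must be set for the default endpoint.')
    if base_url is not None and api_key is None:
        api_key = 'placeholder'  # the client requires a non-None key
    return openai.OpenAI(base_url=base_url, api_key=api_key)


client = make_client(base_url='http://0.0.0.0:8080/v2')  # example URL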
8 changes: 6 additions & 2 deletions llmfoundry/models/model_registry.py
@@ -3,7 +3,9 @@

from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
ComposerHFT5)
from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper,
from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper,
FMAPIChatAPIEvalWrapper,
OpenAICausalLMEvalWrapper,
OpenAIChatAPIEvalWrapper)
from llmfoundry.models.mpt import ComposerMPTCausalLM

@@ -13,5 +15,7 @@
'hf_prefix_lm': ComposerHFPrefixLM,
'hf_t5': ComposerHFT5,
'openai_causal_lm': OpenAICausalLMEvalWrapper,
'openai_chat': OpenAIChatAPIEvalWrapper
'fmapi_causal_lm': FMAPICasualLMEvalWrapper,
'openai_chat': OpenAIChatAPIEvalWrapper,
'fmapi_chat': FMAPIChatAPIEvalWrapper,
}
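With these entries, an eval config can select the new wrappers by name. A sketch of the lookup, assuming the dict shown is the COMPOSER_MODEL_REGISTRY this module exports (the dict's name itself is folded out of the hunk above):

from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY

wrapper_cls = COMPOSER_MODEL_REGISTRY['fmapi_causal_lm']
print(wrapper_cls.__name__)  # FMAPICasualLMEvalWrapper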
2 changes: 1 addition & 1 deletion llmfoundry/tokenizers/tiktoken.py
@@ -198,7 +198,7 @@ def default_chat_template(self):
'{% else %}'
"{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}"
'{% endif %}'
'{% if (add_generation_prompt == true) %}'
'{% if (add_generation_prompt == true and loop.last) %}'
"{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}"
"{% elif (message['role'] == 'assistant') %}"
'{{ eos_token }}'
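The loop.last guard means a multi-turn conversation now gets exactly one trailing generation prompt rather than one per message. A sketch of the effect, assuming tok is a tokenizer carrying the default_chat_template above:

messages = [
    {'role': 'user', 'content': 'Please summarize the goals in this text: ...'},
    {'role': 'assistant', 'content': 'You should go outside and touch grass.'},
    {'role': 'user', 'content': 'What else can I do?'},
]
rendered = tok.apply_chat_template(messages,
                                   tokenize=False,
                                   add_generation_prompt=True)
# With the fix, '<|im_start|>assistant' is appended only once, after the final
# user turn, matching the updated MULTI_TURN_GENERATE_STRING test below.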
26 changes: 20 additions & 6 deletions llmfoundry/utils/model_download_utils.py
@@ -30,6 +30,12 @@
]
PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
TOKENIZER_FILES = [
'special_tokens_map.json',
'tokenizer.json',
'tokenizer.model',
'tokenizer_config.json',
]

ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
ORAS_CLI = 'oras'
@@ -45,6 +51,7 @@ def download_from_hf_hub(
model: str,
save_dir: str,
prefer_safetensors: bool = True,
tokenizer_only: bool = False,
token: Optional[str] = None,
):
"""Downloads model files from a Hugging Face Hub model repo.
@@ -57,6 +64,7 @@
save_dir (str, optional): The local path to the directory where the model files will be downloaded.
prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
available. Defaults to True.
tokenizer_only (bool): If true, only download tokenizer files.
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
`HUGGING_FACE_HUB_TOKEN` environment variable.
@@ -95,10 +103,13 @@
' Please make sure the repo contains either safetensors or pytorch weights.'
)

allow_patterns = TOKENIZER_FILES if tokenizer_only else None

download_start = time.time()
hf_hub.snapshot_download(model,
local_dir=save_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
token=token)
download_duration = time.time() - download_start
log.info(
@@ -221,16 +232,18 @@ def download_from_oras(model: str,
config_file: str,
credentials_dir: str,
save_dir: str,
tokenizer_only: bool = False,
concurrency: int = 10):
"""Download from an OCI-compliant registry using oras.
Args:
model: The name of the model to download.
config_file: Path to a YAML config file that maps model names to registry paths.
credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three
model (str): The name of the model to download.
config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths.
credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three
files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
save_dir: Path to the directory where files will be downloaded.
concurrency: The number of concurrent downloads to run.
save_dir (str): Path to the directory where files will be downloaded.
        tokenizer_only (bool): If true, only download the tokenizer files.
concurrency (int): The number of concurrent downloads to run.
"""
if shutil.which(ORAS_CLI) is None:
raise Exception(
@@ -253,7 +266,8 @@ def _read_secrets_file(secret_file_path: str,):
with open(config_file, 'r', encoding='utf-8') as f:
configs = yaml.safe_load(f.read())

path = configs['models'][model]
config_type = 'tokenizers' if tokenizer_only else 'models'
path = configs[config_type][model]
registry = secrets['registry']

def get_oras_cmd(username: Optional[str] = None,
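A sketch of the new tokenizer-only path through the Hugging Face downloader; the repo id and save path are illustrative:

from llmfoundry.utils.model_download_utils import download_from_hf_hub

# With tokenizer_only=True, allow_patterns restricts the snapshot to the four
# TOKENIZER_FILES entries defined above.
download_from_hf_hub('mosaicml/mpt-7b',
                     save_dir='./mpt-7b-tokenizer',
                     tokenizer_only=True)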
20 changes: 16 additions & 4 deletions scripts/misc/download_model.py
@@ -7,10 +7,11 @@
python download_model.py hf --model mosaicml/mpt-7b --save-dir <save_dir> --token <token>
Download from ORAS registry:
python download_model.py oras --registry <registry> --path mosaicml/mpt-7b --save-dir <save_dir>
python download_model.py oras --model mosaicml/mpt-7b --config-file <config_file> \
--credentials-dir <credentials_dir> --save-dir <save_dir>
Download from an HTTP file server:
python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir>
python download_model.py http --url https://server.com/models/mosaicml/mpt-7b/ --save-dir <save_dir>
Download from an HTTP file server with fallback to Hugging Face Hub:
python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir> \
@@ -56,6 +57,9 @@ def parse_args() -> argparse.Namespace:

base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument('--save-dir', type=str, required=True)
base_parser.add_argument('--tokenizer-only',
default=False,
action='store_true')

# Add subparser for downloading from Hugging Face Hub.
hf_parser = subparsers.add_parser('hf', parents=[base_parser])
Expand Down Expand Up @@ -85,6 +89,9 @@ def parse_args() -> argparse.Namespace:
download_from = args.download_from

if download_from == 'http':
if args.tokenizer_only:
raise ValueError(
'tokenizer-only is not currently supported for http.')
try:
download_from_http_fileserver(args.url, args.save_dir,
args.ignore_cert)
@@ -109,7 +116,12 @@ def parse_args() -> argparse.Namespace:
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token,
tokenizer_only=args.tokenizer_only,
prefer_safetensors=args.prefer_safetensors)
elif download_from == 'oras':
download_from_oras(args.model, args.config_file, args.credentials_dir,
args.save_dir, args.concurrency)
download_from_oras(args.model,
args.config_file,
args.credentials_dir,
args.save_dir,
tokenizer_only=args.tokenizer_only,
concurrency=args.concurrency)
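The new flag threads through both the hf and oras backends (the http path rejects it, as shown above); an illustrative invocation:

python download_model.py hf --model mosaicml/mpt-7b --save-dir ./mpt-7b-tokenizer --tokenizer-only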
4 changes: 2 additions & 2 deletions setup.py
@@ -52,10 +52,10 @@
install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.2,<0.18',
'accelerate>=0.25,<0.26', # for HF inference `device_map`
'transformers>=4.36,<4.37',
'transformers>=4.37,<4.38',
'mosaicml-streaming>=0.7.2,<0.8',
'torch>=2.1,<2.1.1',
'datasets==2.15.0',
'datasets>=2.16,<2.17',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
'sentencepiece==0.1.97',
'einops==0.7.0',
10 changes: 10 additions & 0 deletions tests/tokenizers/test_tiktoken.py
@@ -108,6 +108,12 @@
'Please summarize the goals in this text:\n\nGoing outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.',
'role':
'user'
}, {
'content': 'You should go outside and touch grass.',
'role': 'assistant'
}, {
'content': 'What else can I do?',
'role': 'user'
}]]

MULTI_TURN_GENERATE_STRING = [
@@ -118,6 +124,10 @@
Going outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.<|im_end|>
<|im_start|>assistant
You should go outside and touch grass.<|im_end|><|endoftext|>
<|im_start|>user
What else can I do?<|im_end|>
<|im_start|>assistant
"""
]

1 change: 1 addition & 0 deletions tests/utils/test_model_download_utils.py
@@ -110,6 +110,7 @@ def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock,
mock_snapshot_download.assert_called_once_with(
test_repo_id,
local_dir=save_dir,
allow_patterns=None,
ignore_patterns=expected_ignore_patterns,
token=None)

