diff --git a/llmfoundry/models/inference_api_wrapper/__init__.py b/llmfoundry/models/inference_api_wrapper/__init__.py new file mode 100644 index 0000000000..b9cd71ad47 --- /dev/null +++ b/llmfoundry/models/inference_api_wrapper/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry.models.inference_api_wrapper.interface import \ + InferenceAPIEvalWrapper +from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAITokenizerWrapper) + +__all__ = [ + 'OpenAICausalLMEvalWrapper', + 'OpenAIChatAPIEvalWrapper', + 'OpenAITokenizerWrapper', + 'InferenceAPIEvalWrapper', +] diff --git a/llmfoundry/models/inference_api_wrapper/interface.py b/llmfoundry/models/inference_api_wrapper/interface.py new file mode 100644 index 0000000000..2d84599772 --- /dev/null +++ b/llmfoundry/models/inference_api_wrapper/interface.py @@ -0,0 +1,110 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Optional + +import torch +from composer.core.types import Batch +from composer.metrics import InContextLearningMetric +from composer.metrics.nlp import (InContextLearningLMAccuracy, + InContextLearningLMExpectedCalibrationError, + InContextLearningMCExpectedCalibrationError, + InContextLearningMultipleChoiceAccuracy, + InContextLearningQAAccuracy, + LanguageCrossEntropy, LanguagePerplexity) +from composer.models import ComposerModel +from torchmetrics import Metric +from transformers import AutoTokenizer + + +class InferenceAPIEvalWrapper(ComposerModel): + + def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer): + self.tokenizer = tokenizer + self.labels = None + # set up training and eval metrics + eval_metrics = [ + LanguageCrossEntropy(), + LanguagePerplexity(), + InContextLearningLMAccuracy(), + InContextLearningMultipleChoiceAccuracy(), + InContextLearningQAAccuracy(), + InContextLearningLMExpectedCalibrationError(), + InContextLearningMCExpectedCalibrationError() + ] + self.eval_metrics = { + metric.__class__.__name__: metric for metric in eval_metrics + } + super().__init__() + + def get_metrics(self, is_train: bool = False): + if is_train: + raise NotImplementedError( + 'You cannot use inference wrappers for training') + else: + metrics = self.eval_metrics + + return metrics if metrics else {} + + def get_next_token_logit_tensor(self, + prompt: str) -> Optional[torch.Tensor]: + raise NotImplementedError + + def rebatch(self, batch: Batch): + # default is a no-op, but Chat API modifies these + return batch + + def eval_forward(self, batch: Batch, outputs: Optional[Any] = None): + # If the batch mode is generate, we will generate a requested number of tokens using the underlying + # model's generate function. Extra generation kwargs can be passed in via the batch. Strings will + # be returned from eval_forward + output_logits_batch = [] + for tokens, cont_idxs in zip(batch['input_ids'], + batch['continuation_indices']): + + seqlen = tokens.shape[0] + tokens = tokens.tolist() + cont_idxs = cont_idxs.tolist() + expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1] + output_logits = torch.nn.functional.one_hot( + torch.tensor(tokens[1:cont_idxs[0]]), + num_classes=self.tokenizer.vocab_size) + for i in range(len(expected_cont_tokens)): + # decode one token at a time + prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] + + expected_cont_tokens[0:i]) + next_logit_tensor = self.get_next_token_logit_tensor(prompt) + if next_logit_tensor is None: + continue + output_logits = torch.cat( + [output_logits, + next_logit_tensor.reshape(1, -1)]) + padding = torch.nn.functional.one_hot( + torch.full((seqlen - output_logits.shape[0],), + self.tokenizer.pad_token_id), + num_classes=self.tokenizer.vocab_size) + output_logits = torch.cat([output_logits, padding]) + output_logits_batch.append(output_logits) + + return torch.stack(output_logits_batch).to(batch['input_ids'].device) + + def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None: + batch = self.rebatch(batch) + self.labels = batch.pop('labels') + self.labels[:, :-1] = self.labels[:, 1:].clone() + self.labels[:, -1] = -100 + if isinstance(metric, InContextLearningMetric) and batch.get( + 'mode', None) == 'icl_task': + assert self.labels is not None + metric.update(batch, outputs, self.labels) + else: + raise NotImplementedError( + 'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task' + ) + + def forward(self): + raise NotImplementedError( + "Inference API wrapper doesn't support forward") + + def loss(self): + raise NotImplementedError("Inference API wrapper doesn't support loss") diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py new file mode 100644 index 0000000000..14228134d2 --- /dev/null +++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py @@ -0,0 +1,324 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Implements a OpenAI chat and causal LM inference API wrappers.""" + +import logging +import os +from time import sleep +from typing import Any, Dict, List, Optional, Union + +import torch +from composer.core.types import Batch +from composer.utils.import_helpers import MissingConditionalImportError +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +log = logging.getLogger(__name__) + +from llmfoundry.models.inference_api_wrapper.interface import \ + InferenceAPIEvalWrapper + +__all__ = [ + 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', + 'OpenAITokenizerWrapper' +] + +Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + +MAX_RETRIES = 10 + + +class OpenAITokenizerWrapper(AutoTokenizer): + + def __init__(self, name: str) -> None: + try: + import tiktoken + except ImportError as e: + raise MissingConditionalImportError( + extra_deps_group='openai', + conda_package='tiktoken', + conda_channel='conda-forge') from e + self.tokenizer = tiktoken.encoding_for_model(name) + + def __call__(self, x: str, add_special_tokens: bool = False): + if add_special_tokens: + raise ValueError( + 'OpenAITokenizerWrapper only supports add_special_tokens=False') + return self.encode(x) + + def encode(self, + x: Union[str, List[str]], + add_special_tokens: bool = False): + if add_special_tokens: + raise ValueError( + 'OpenAITokenizerWrapper only supports add_special_tokens=False') + if isinstance(x, str): + return { + 'input_ids': + self.tokenizer.encode(x, allowed_special={'<|endoftext|>'}) + } + elif isinstance(x, + list): # pyright: ignore [reportUnnecessaryIsInstance] + return { + 'input_ids': + self.tokenizer.encode_batch( + x, allowed_special={'<|endoftext|>'}) + } + else: + raise ValueError( + f'`encode` argument must be str or List[str], got: {type(x)}') + + def decode( + self, + x: Union[List[int], List[List[int]]], + ): + if len(x) > 0 and isinstance(x[0], list): + return self.tokenizer.decode_batch( + x) # pyright: ignore [reportGeneralTypeIssues] + else: + assert isinstance(x, list) + return self.tokenizer.decode( + x) # pyright: ignore [reportGeneralTypeIssues] + + @property + def pad_token_id(self): + return self.tokenizer.eot_token + + @property + def eos_token_id(self): + return self.tokenizer.eot_token + + @property + def vocab_size(self): + return self.tokenizer.n_vocab + + def construct_logit_tensor(self, logprobs: Dict[str, float]): + """Construct tensor of shape (vocab_size,) mapping words to logprobs. + + Args: + logprobs (Dict[str, float]): Dictionary mapping tokens to log probabilities assigned to them by the model. + """ + tensor = torch.tensor([min(logprobs.values()) - 1] * (self.vocab_size)) + for k in logprobs: + encoding = self.encode(k)['input_ids'] + idx = encoding[0] + tensor[idx] = logprobs[k] + return tensor + + +class OpenAIEvalInterface(InferenceAPIEvalWrapper): + + def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: + super().__init__(model_cfg, tokenizer) + try: + import openai + except ImportError as e: + raise MissingConditionalImportError( + extra_deps_group='openai', + conda_package='openai', + conda_channel='conda-forge') from e + openai.api_key = os.getenv('OPENAI_API_KEY') + self.model_name = model_cfg['version'] + + def generate_completion(self, prompt: str, num_tokens: int): + raise NotImplementedError() + + def process_result(self, completion: Optional[dict]): + raise NotImplementedError() + + def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1): + completion = self.try_generate_completion(prompt, num_tokens) + return self.process_result(completion) + + def try_generate_completion(self, prompt: str, num_tokens: int): + try: + from openai.error import RateLimitError + except ImportError as e: + raise MissingConditionalImportError( + extra_deps_group='openai', + conda_package='openai', + conda_channel='conda-forge') from e + tries = 0 + completion = None + while tries < MAX_RETRIES: + tries += 1 + try: + + completion = self.generate_completion(prompt, num_tokens) + break + except RateLimitError as e: + if 'You exceeded your current quota' in str(e._message): + raise e + sleep(60) + continue + except Exception: + continue + return completion + + +class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface): + + def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: + super().__init__(model_cfg, tokenizer) + try: + import openai + except ImportError as e: + raise MissingConditionalImportError( + extra_deps_group='openai', + conda_package='openai', + conda_channel='conda-forge') from e + + self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create( + self.model_name, + messages=[{ + 'role': 'user', + 'content': prompt + }], + max_tokens=num_tokens, + temperature=0.0) + + def retokenize(self, tokens: List[int], cont_idxs: List[int]): + """Chat API will never respond with a word-initial space. + + If the continuation tokens begin with a word initial space, we need to + re-tokenize with the space removed. + """ + original_len = len(tokens) + retokenized_continuation = self.tokenizer.encode( + self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] + + 1]).strip())['input_ids'] + + # replace the original continuation with the retokenized continuation + padding + padding = [tokens[-1]] * ( + len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation)) + tokens = tokens[:cont_idxs[0]] + retokenized_continuation + padding + + if len(tokens) > original_len: + # this only happens if we were already at max seq len and the continuation got LARGER + tokens = tokens[-original_len:] + cont_idxs = list( + range(original_len - len(retokenized_continuation), + original_len)) + else: + cont_idxs = list( + range(cont_idxs[0], + cont_idxs[0] + len(retokenized_continuation))) + return torch.tensor(tokens), torch.tensor(cont_idxs) + + def rebatch(self, batch: Batch): + """Chat API tokenization has different behavior than GPT3. + + Model responses will never begin with spaces even if the continuation is + expected to, so we need to retokenize the input to account for that. + """ + new_batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = { + 'input_ids': [], + 'continuation_indices': [], + 'labels': [] + } + for tokens, cont_idxs in zip(batch['input_ids'], + batch['continuation_indices']): + tokens, cont_idxs = self.retokenize(tokens.tolist(), + cont_idxs.tolist()) + + assert isinstance(new_batch['input_ids'], list) + new_batch['input_ids'].append(tokens) + assert isinstance(new_batch['labels'], list) + new_batch['labels'].append(tokens) + assert isinstance(new_batch['continuation_indices'], list) + new_batch['continuation_indices'].append(cont_idxs) + + new_batch.update({ + k: torch.stack(new_batch[k]) # pyright: ignore + for k in ['input_ids', 'labels'] + }) + + new_batch.update({k: v for k, v in batch.items() if k not in new_batch}) + + return new_batch + + def eval_forward(self, batch: Batch, outputs: Optional[Any] = None): + # Override the base class because Chat's API always strips spacing from model outputs resulting in different tokens + # than what the continuation would expect. + # Get around this issue by retokenizing the batch to remove spacing from the continuation as well as + # decoding the whole continuation at once. + output_logits_batch = [] + batch = self.rebatch(batch) + for tokens, cont_idxs in zip(batch['input_ids'], + batch['continuation_indices']): + + seqlen = tokens.shape[0] + tokens = tokens.tolist() + cont_idxs = cont_idxs.tolist() + expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1] + output_logits = torch.nn.functional.one_hot( + torch.tensor(tokens[1:cont_idxs[0]]), + num_classes=self.tokenizer.vocab_size) + + prompt = self.tokenizer.decode(tokens[:cont_idxs[0]]) + next_logit_tensor = self.get_next_token_logit_tensor( + prompt, num_tokens=len(expected_cont_tokens)) + + if next_logit_tensor is not None: + output_logits = torch.cat([output_logits, next_logit_tensor]) + padding = torch.nn.functional.one_hot( + torch.full((seqlen - output_logits.shape[0],), + self.tokenizer.pad_token_id), + num_classes=self.tokenizer.vocab_size) + output_logits = torch.cat([output_logits, padding]) + output_logits_batch.append(output_logits) + + return torch.stack(output_logits_batch).to(batch['input_ids'].device) + + def process_result(self, completion: Optional[dict]): + assert isinstance(completion, dict) + if len(completion['choices']) > 0: + tensors = [] + for t in self.tokenizer.encode(completion['choices'][0]['message'] + ['content'])['input_ids']: + tensors.append( + self.tokenizer.construct_logit_tensor( + {self.tokenizer.decode([t]): 0.0})) + + if len(tensors) == 0: + return None + return torch.stack(tensors) + else: + # the model sometimes stops early even though we are still requesting tokens! + # not sure if there's a fix + return None + + +class OpenAICausalLMEvalWrapper(OpenAIEvalInterface): + + def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: + super().__init__(model_cfg, tokenizer) + try: + import openai + except ImportError as e: + raise MissingConditionalImportError( + extra_deps_group='openai', + conda_package='openai', + conda_channel='conda-forge') from e + + self.generate_completion = lambda prompt, num_tokens: openai.Completion.create( + engine=self.model_name, + prompt=prompt, + max_tokens=1, + logprobs=5, + temperature=0.0) + + def process_result(self, completion: Optional[dict]): + if completion is None: + raise ValueError("Couldn't generate model output") + + assert isinstance(completion, dict) + if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0: + tensor = self.tokenizer.construct_logit_tensor( + dict(completion['choices'][0]['logprobs']['top_logprobs'][0])) + return tensor + else: + # the model sometimes stops early even though we are still requesting tokens! + # not sure if there's a fix + return None diff --git a/llmfoundry/models/model_registry.py b/llmfoundry/models/model_registry.py index 02a709740e..be09a69835 100644 --- a/llmfoundry/models/model_registry.py +++ b/llmfoundry/models/model_registry.py @@ -3,6 +3,8 @@ from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM, ComposerHFT5) +from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper, + OpenAIChatAPIEvalWrapper) from llmfoundry.models.mpt import ComposerMPTCausalLM COMPOSER_MODEL_REGISTRY = { @@ -10,4 +12,6 @@ 'hf_causal_lm': ComposerHFCausalLM, 'hf_prefix_lm': ComposerHFPrefixLM, 'hf_t5': ComposerHFT5, + 'openai_causal_lm': OpenAICausalLMEvalWrapper, + 'openai_chat': OpenAIChatAPIEvalWrapper } diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 32f7aceea3..b89ff899ee 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -30,6 +30,8 @@ GlobalLRScaling, HuggingFaceCheckpointer, LayerFreezing, MonolithicCheckpointSaver, ScheduledGarbageCollector) +from llmfoundry.models.inference_api_wrapper.openai_causal_lm import \ + OpenAITokenizerWrapper from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, DecoupledLionW, DecoupledLionW_8bit) @@ -110,6 +112,8 @@ def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: return WandBLogger(**kwargs) elif name == 'tensorboard': return TensorboardLogger(**kwargs) + elif name == 'in_memory_logger': + return InMemoryLogger(**kwargs) elif name == 'mlflow': return MLFlowLogger(**kwargs) elif name == 'inmemory': @@ -164,21 +168,24 @@ def build_scheduler(name: str, def build_tokenizer( tokenizer_name: str, tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase: - os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' - os.environ['TOKENIZERS_PARALLELISM'] = 'false' + if tokenizer_name == 'openai': + return OpenAITokenizerWrapper(**tokenizer_kwargs) + else: + os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' + os.environ['TOKENIZERS_PARALLELISM'] = 'false' - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, - **tokenizer_kwargs) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, + **tokenizer_kwargs) - # HuggingFace does not respect the model_max_length kwarg, and overrides it with - # min(kwargs['model_max_length'], original_config['model_max_length']), so we - # explicitly set it here - tokenizer.model_max_length = tokenizer_kwargs.get( - 'model_max_length', - int(1e30), - ) + # HuggingFace does not respect the model_max_length kwarg, and overrides it with + # min(kwargs['model_max_length'], original_config['model_max_length']), so we + # explicitly set it here + tokenizer.model_max_length = tokenizer_kwargs.get( + 'model_max_length', + int(1e30), + ) - return tokenizer + return tokenizer def build_icl_evaluators( diff --git a/mcli/mcli-openai-eval.yaml b/mcli/mcli-openai-eval.yaml new file mode 100644 index 0000000000..6275d9d578 --- /dev/null +++ b/mcli/mcli-openai-eval.yaml @@ -0,0 +1,65 @@ +integrations: +- integration_type: git_repo + git_repo: mosaicml/llm-foundry + git_branch: # use your branch + # git_commit: 29d65cc26853c09f6de7542978056ddb0b07e98c # OR use your commit hash + pip_install: -e ".[gpu,openai]" + ssh_clone: false # Should be true if using a private repo + +command: | + cd llm-foundry/scripts + composer eval/eval.py /mnt/config/parameters.yaml + +# Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME +run_name: openai-eval +# gpu_num: # +# gpu_type: # +cluster: # replace with your cluster here! + +image: mosaicml/llm-foundry:2.0.1_cu118-latest + +# The below is injected as a YAML file: /mnt/config/parameters.yaml +parameters: + seed: 1 + max_seq_len: 1024 + device_eval_batch_size: 4 + models: + - + model_name: openai/davinci + model: + name: openai_causal_lm + version: davinci + tokenizer: + name: openai + kwargs: + name: davinci + - + model_name: openai/ada + model: + name: openai_causal_lm + version: ada + tokenizer: + name: openai + kwargs: + name: ada + - + model_name: openai/gpt-4 + model: + name: openai_chat + version: gpt-4 + tokenizer: + name: openai + kwargs: + name: gpt-4 + - + model_name: openai/gpt-3.5-turbo + model: + name: openai_chat + version: gpt-3.5-turbo + tokenizer: + name: openai + kwargs: + name: gpt-3.5-turbo + + icl_tasks: 'eval/yamls/lm_tasks.yaml' + eval_gauntlet: 'eval/yamls/eval_gauntlet.yaml' diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 1ba723a172..24e05528a6 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -214,7 +214,10 @@ def main(cfg: DictConfig): device_eval_batch_size: int = pop_config(cfg, 'device_eval_batch_size', must_exist=True) - precision: str = pop_config(cfg, 'precision', must_exist=True) + precision: str = pop_config(cfg, + 'precision', + must_exist=False, + default_value=None) python_log_level: Optional[str] = pop_config(cfg, 'python_log_level', must_exist=False, diff --git a/scripts/eval/yamls/eval_gauntlet.yaml b/scripts/eval/yamls/eval_gauntlet.yaml index 0234cd1591..e86c505e9a 100644 --- a/scripts/eval/yamls/eval_gauntlet.yaml +++ b/scripts/eval/yamls/eval_gauntlet.yaml @@ -112,6 +112,50 @@ eval_gauntlet: - name: boolq num_fewshot: 10 random_baseline: 0.5 + - name: world_knowledge_lm_task_subscore + benchmarks: + - name: jeopardy + num_fewshot: 10 + random_baseline: 0 + - name: bigbench_qa_wikidata + num_fewshot: 10 + random_baseline: 0 + - name: language_understanding_lm_task_subscore + benchmarks: + - name: lambada_openai + num_fewshot: 0 + random_baseline: 0.0 + - name: bigbench_conlang_translation + num_fewshot: 0 + random_baseline: 0.0 + - name: symbolic_problem_solving_lm_task_subscore + benchmarks: + - name: bigbench_dyck_languages + num_fewshot: 10 + random_baseline: 0 + - name: bigbench_cs_algorithms + num_fewshot: 10 + random_baseline: 0 + - name: bigbench_operators + num_fewshot: 10 + random_baseline: 0.0 + - name: bigbench_repeat_copy_logic + num_fewshot: 10 + random_baseline: 0.0 + - name: simple_arithmetic_withspaces + num_fewshot: 10 + random_baseline: 0.0 + - name: simple_arithmetic_nospaces + num_fewshot: 10 + random_baseline: 0.0 + - name: reading_comprehension_lm_task_subscore + benchmarks: + - name: pubmed_qa_labeled + num_fewshot: 10 + random_baseline: 0.0 + - name: squad + num_fewshot: 10 + random_baseline: 0 - name: world_knowledge_lite benchmarks: - name: jeopardy diff --git a/scripts/eval/yamls/openai_eval.yaml b/scripts/eval/yamls/openai_eval.yaml new file mode 100644 index 0000000000..e1afe78015 --- /dev/null +++ b/scripts/eval/yamls/openai_eval.yaml @@ -0,0 +1,34 @@ +seed: 1 +max_seq_len: 1024 +device_eval_batch_size: 4 +models: +- + model_name: openai/davinci + model: + name: openai_causal_lm + version: davinci + tokenizer: + name: openai + kwargs: + name: davinci +- + model_name: openai/gpt-4 + model: + name: openai_chat + version: gpt-4 + tokenizer: + name: openai + kwargs: + name: gpt-4 +- + model_name: openai/gpt-3.5-turbo + model: + name: openai_chat + version: gpt-3.5-turbo + tokenizer: + name: openai + kwargs: + name: gpt-3.5-turbo + +icl_tasks: 'eval/yamls/lm_tasks.yaml' +eval_gauntlet: 'eval/yamls/eval_gauntlet.yaml' diff --git a/setup.py b/setup.py index 1a93bd05f7..5cd1309922 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,10 @@ 'peft==0.4.0', ] +extra_deps['openai'] = [ + 'openai==0.27.8', + 'tiktoken==0.4.0', +] extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps) setup( diff --git a/tests/test_inference_api_eval_wrapper.py b/tests/test_inference_api_eval_wrapper.py new file mode 100644 index 0000000000..ba065b6020 --- /dev/null +++ b/tests/test_inference_api_eval_wrapper.py @@ -0,0 +1,141 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict +from unittest.mock import patch + +import pytest +from omegaconf import DictConfig, ListConfig + +from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper, + OpenAIChatAPIEvalWrapper, + OpenAITokenizerWrapper) +from llmfoundry.utils.builders import build_icl_evaluators + + +def load_icl_config(): + return DictConfig({ + 'icl_tasks': + ListConfig([ + DictConfig({ + 'label': + 'jeopardy', + 'dataset_uri': + 'scripts/eval/local_data/world_knowledge/jeopardy_all.jsonl', + 'num_fewshot': [0, 1], + 'icl_task_type': + 'language_modeling', + 'continuation_delimiter': + '\nAnswer: ', + 'has_categories': + True + }) + ]) + }) + + +def mock_create(**kwargs: Dict[str, str]): + prompt = kwargs['prompt'] + if prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer:': # pyright: ignore[reportUnnecessaryComparison] + return { + 'choices': [{ + 'logprobs': { + 'top_logprobs': [{ + ' Tre': 0, + }], + }, + }], + } + elif prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer: Tre': # pyright: ignore[reportUnnecessaryComparison] + return { + 'choices': [{ + 'logprobs': { + 'top_logprobs': [{ + 'ason': 0, + }], + }, + }], + } + elif prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer: Treason': # pyright: ignore[reportUnnecessaryComparison] + return { + 'choices': [{ + 'logprobs': { + 'top_logprobs': [{ + '!': 0, + }], + }, + }], + } + else: + # dummy token to make sure the model is incorrect on any other prompt + return { + 'choices': [{ + 'logprobs': { + 'top_logprobs': [{ + ' ': 0, + }], + }, + }], + } + + +def test_openai_api_eval_wrapper(tmp_path: str): + _ = pytest.importorskip('openai') + with patch('openai.Completion') as mock: + mock.create = mock_create + model_name = 'davinci' + tokenizer = OpenAITokenizerWrapper(model_name) + model = OpenAICausalLMEvalWrapper(model_cfg={'version': model_name}, + tokenizer=tokenizer) + task_cfg = load_icl_config() + evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks, + tokenizer, + 1024, + 2, + destination_dir=str(tmp_path)) + + batch = next(evaluators[0].dataloader.dataloader.__iter__()) + result = model.eval_forward(batch) + model.update_metric(batch, + result, + metric=model.get_metrics() + ['InContextLearningLMAccuracy']) # pyright: ignore + acc = model.get_metrics( + )['InContextLearningLMAccuracy'].compute( # pyright: ignore + ) # pyright: ignore + assert acc == 0.5 + + +def test_chat_api_eval_wrapper(tmp_path: str): + _ = pytest.importorskip('openai') + with patch('openai.ChatCompletion') as mock: + mock.create.return_value = { + 'choices': [{ + 'message': { + 'role': 'assistant', + 'content': 'Treason!' + }, + }], + } + model_name = 'gpt-3.5-turbo' + tokenizer = OpenAITokenizerWrapper(model_name) + chatmodel = OpenAIChatAPIEvalWrapper(model_cfg={'version': model_name}, + tokenizer=tokenizer) + task_cfg = load_icl_config() + evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks, + tokenizer, + 1024, + 2, + destination_dir=str(tmp_path)) + + batch = next(evaluators[0].dataloader.dataloader.__iter__()) + result = chatmodel.eval_forward(batch) + chatmodel.update_metric( + batch, + result, + metric=chatmodel.get_metrics() + ['InContextLearningLMAccuracy']) # pyright: ignore + acc = chatmodel.get_metrics( + )['InContextLearningLMAccuracy'].compute( # pyright: ignore + ) + assert acc == 0.5