diff --git a/README.md b/README.md
index 04bad9c519..1d3f6d5df4 100644
--- a/README.md
+++ b/README.md
@@ -228,7 +228,7 @@ python inference/convert_composer_to_hf.py \
# --hf_repo_for_upload user-org/repo-name
# Evaluate the model on a subset of tasks
-python eval/eval.py \
+composer eval/eval.py \
eval/yamls/hf_eval.yaml \
icl_tasks=eval/yamls/copa.yaml \
model_name_or_path=mpt-125m-hf
diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index 3bb9eed043..51fa67993a 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -4,6 +4,11 @@
import torch
try:
+ # Before importing any transformers models, we need to disable transformers flash attention if
+ # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
+ # gated import otherwise.
+ import transformers
+
from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
@@ -14,8 +19,8 @@
ComposerHFT5)
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
- flash_attn_fn, scaled_multihead_dot_product_attention,
- triton_flash_attn_fn)
+ flash_attn_fn, is_flash_v1_installed,
+ scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
build_ffn)
@@ -24,6 +29,8 @@
MPTForCausalLM, MPTModel,
MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
+ if is_flash_v1_installed():
+ transformers.utils.is_flash_attn_available = lambda: False
except ImportError as e:
try:
diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index aa3beda513..3050529a5a 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -4,13 +4,14 @@
import contextlib
import copy
import logging
+import math
import os
import tempfile
from pathlib import Path
from typing import Optional, Union
import torch
-from composer.core import Callback, Event, State, Time
+from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
@@ -83,6 +84,13 @@ def __init__(
self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
+
+ if isinstance(save_interval, str):
+ save_interval = Time.from_timestring(save_interval)
+ if isinstance(save_interval, int):
+ save_interval = Time(save_interval, TimeUnit.EPOCH)
+
+ self.save_interval = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.upload_to_object_store = (self.backend != '')
@@ -128,6 +136,21 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'5GB')
+ def _is_last_batch(self, state: State):
+ elapsed_duration = state.get_elapsed_duration()
+ if elapsed_duration is not None and elapsed_duration >= 1.0:
+ return True
+
+ assert state.max_duration is not None # for pyright
+ # If the save interval is specified as 1dur, and the max duration is in epoch units
+ # we need a special case to identify we are on the last batch and should write the mlflow checkpoint
+ if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH:
+ assert state.dataloader_len is not None # for pyright
+ return int(state.timestamp.batch) % math.ceil(
+ state.max_duration.value * state.dataloader_len) == 0
+
+ return False
+
def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused
@@ -224,8 +247,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)
- elapsed_duration = state.get_elapsed_duration()
- if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
+ if self.mlflow_registered_model_name and self._is_last_batch(
+ state):
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 13857e9bb9..eb90b07045 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -24,8 +24,7 @@
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.layers.llama_attention_monkeypatch import \
- get_llama_attention_patch_fn
+from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
try:
@@ -95,12 +94,28 @@ def __init__(self, om_model_config: Union[DictConfig,
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
+ use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
+ False)
+ if use_flash_attention_2 and not is_flash_v2_installed():
+ raise ValueError(
+ 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ + 'Please install flash_attn==2.3.2`.')
+
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)
+ # This is not how you are supposed to set this, but transformers currently only
+ # supports enabling flash attention 2 when using the from_pretrained API.
+ # We need to support it for both from_pretrained and from_config, so we have to
+ # set the private attribute here. This will just skip all of transformers'
+ # validation logic that it is ok to use flash attention 2, so we check
+ # whether it is installed above, and whether the chosen config supports it here.
+ # https://github.com/huggingface/transformers/issues/26878
+ config._flash_attn_2_enabled = use_flash_attention_2
+
# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
@@ -200,6 +215,9 @@ def __init__(self, om_model_config: Union[DictConfig,
)
from transformers.models.llama.modeling_llama import \
LlamaAttention
+
+ from llmfoundry.models.layers.llama_attention_monkeypatch import \
+ get_llama_attention_patch_fn
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False
diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py
index 001be6a030..45192e09dd 100644
--- a/llmfoundry/tokenizers/tiktoken.py
+++ b/llmfoundry/tokenizers/tiktoken.py
@@ -21,6 +21,7 @@ def __init__(self,
model_name: Optional[str] = None,
encoding_name: Optional[str] = None,
add_bos_token: bool = False,
+ add_eos_token: bool = False,
unk_token: Optional[str] = '<|endoftext|>',
eos_token: Optional[str] = '<|endoftext|>',
bos_token: Optional[str] = '<|endoftext|>',
@@ -36,6 +37,7 @@ def __init__(self,
encoding_name (Optional[str], optional): The name of the encoding to load from tiktoken. Defaults to None.
Either model_name or encoding_name must be set, but not both.
add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False.
+ add_eos_token (bool, optional): Whether to add eos tokens. Defaults to False.
unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'.
eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
@@ -66,10 +68,12 @@ def __init__(self,
'You need to specify either model_name or encoding_name.')
self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
super().__init__(model_name=model_name,
encoding_name=encoding_name,
add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
unk_token=unk_token,
eos_token=eos_token,
bos_token=bos_token,
@@ -151,7 +155,7 @@ def convert_ids_to_tokens(
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
- return self.added_tokens_decoder[ids]
+ return str(self.added_tokens_decoder[ids])
return self._convert_id_to_token(ids)
@@ -167,7 +171,7 @@ def convert_ids_to_tokens(
if index in self.added_tokens_decoder:
tokens.append(self.encoding.decode(current_stream))
current_stream = []
- tokens.append(self.added_tokens_decoder[index])
+ tokens.append(str(self.added_tokens_decoder[index]))
else:
current_stream.append(index)
@@ -179,17 +183,15 @@ def build_inputs_with_special_tokens(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
- if self.add_bos_token:
- bos_token_ids = [self.bos_token_id]
- else:
- bos_token_ids = []
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
- output = bos_token_ids + token_ids_0
+ output = bos_token_id + token_ids_0 + eos_token_id
- if token_ids_1 is None:
- return output
+ if token_ids_1 is not None:
+ output = output + bos_token_id + token_ids_1 + eos_token_id
- return output + bos_token_ids + token_ids_1
+ return output
def get_special_tokens_mask(
self,
@@ -221,15 +223,13 @@ def get_special_tokens_mask(
token_ids_1=token_ids_1,
already_has_special_tokens=True)
- if not self.add_bos_token:
- return super().get_special_tokens_mask(
- token_ids_0=token_ids_0,
- token_ids_1=token_ids_1,
- already_has_special_tokens=False)
+ bos_token_id = [1] if self.add_bos_token else []
+ eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
- return [1] + ([0] * len(token_ids_0))
- return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+ return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
+ bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)
def create_token_type_ids_from_sequences(
self,
diff --git a/scripts/train/README.md b/scripts/train/README.md
index f10fdf59f0..4c706dc040 100644
--- a/scripts/train/README.md
+++ b/scripts/train/README.md
@@ -5,14 +5,15 @@ This README walks through pretraining and finetuning a large language model usin
#### Table of Contents
1. [Part 1: LLM Pretraining](#llmpretraining)
1. [Installation](#installation)
- 2. [Dataset Preparation](#datasetpreparation)
- 3. [How to start single and multi-node pretraining](#howtostartpretraining)
-2. [Part 2: LLM Finetuning](#llmfinetuning)
+ 1. [Dataset Preparation](#datasetpreparation)
+ 1. [How to start single and multi-node pretraining](#howtostartpretraining)
+1. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Using a dataset on the HuggingFace Hub](#hfdataset)
- 2. [Using a local dataset](#localdataset)
- 3. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
-3. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
-4. [FAQ: Optimizing Performance](#optimizingperformance)
+ 1. [Using a local dataset](#localdataset)
+ 1. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
+1. [Using Flash Attention](#flashattention)
+1. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
+1. [FAQ: Optimizing Performance](#optimizingperformance)
# Part 1: LLM Pretraining
@@ -332,6 +333,53 @@ train_loader:
...
```
+# Using Flash Attention
+
+Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. Next, how you specify to use Flash Attention depends on which model you are using.
+
+For MPT, you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: mpt_causal_lm
+ ...
+ attn_config:
+ # Will use either V1 or V2 depending on what is installed
+ # "triton" will use the Triton implementation
+ attn_impl: flash
+ ...
+```
+
+If loading MPT from the HuggingFace Hub, you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ pretrained_model_name_or_path: mosaicml/mpt-7b
+ ...
+ config_overrides:
+ # Will use either V1 or V2 depending on what is installed
+ # "triton" will use the Triton implementation
+ attn_config:
+ attn_impl: flash
+ ...
+```
+
+For any HuggingFace model that supports Flash Attention (e.g. Llama and Mistral), you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ use_flash_attention_2: True # Will be automatically set to True if Flash Attention V2 is installed and the model supports it
+ ...
+```
+HuggingFace models currently only support Flash Attention V2.
+
+For Llama specifically, we have another option if you would like to use the Triton implementation of Flash Attention. You can specify this in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
+ attention_patch_type: triton
+ ...
+```
# FAQ: How many GPUs do I need to train a LLM?
This is a complicated question in general, but if we assume that you are using FSDP with `FULL_SHARD`,
diff --git a/scripts/train/benchmarking/README.md b/scripts/train/benchmarking/README.md
index 1bbf399e88..c3c8bc1c74 100644
--- a/scripts/train/benchmarking/README.md
+++ b/scripts/train/benchmarking/README.md
@@ -139,7 +139,7 @@ Our microbatching engine enables microbatch sizes that do not divde Global Batch
## A100 80GB with 1600 Gbps node-node interconnect (RoCE)
| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 70b | 2048 | 64 | a100_80gb | 53.33 | 71.1 | 166 | 8 | 4 | 2048 | 12 | 26274 | 410 | 4194304 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
| 70b | 2048 | 32 | a100_80gb | 48.56 | 64.75 | 151 | 2 | 16 | 1024 | 5 | 11962 | 373 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
| 30b | 8192 | 8 | a100_80gb | 39.38 | 52.5 | 122 | 1 | 21 | 168 | 0 | 4594 | 574 | 1376256 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 30019254272 |
@@ -205,7 +205,7 @@ Our microbatching engine enables microbatch sizes that do not divde Global Batch
## A100 40GB with 1600 Gbps node-node interconnect (RoCE)
| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP| MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 70b | 2048 | 128 | a100_40gb | 48.91 | 65.21 | 152 | 4 | 1 | 512 | 23 | 48194 | 376 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
| 70b | 2048 | 64 | a100_40gb | 35.87 | 47.82 | 111 | 2 | 1 | 128 | 8 | 17672 | 276 | 262144 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
| 30b | 2048 | 128 | a100_40gb | 52.25 | 69.66 | 163 | 6 | 1 | 768 | 54 | 110803 | 865 | 1572864 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 5e93e33056..e29f2c9a47 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -1,9 +1,11 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
+import gc
import logging
import os
import sys
+import time
import warnings
from typing import Any, Dict, List, Optional, Union
@@ -11,6 +13,9 @@
from composer import Trainer
from composer.core import Evaluator
from composer.core.callback import Callback
+from composer.loggers import MosaicMLLogger
+from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR,
+ MOSAICML_PLATFORM_ENV_VAR)
from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
cyclic_schedule)
from composer.utils import dist, get_device, reproducibility
@@ -212,6 +217,12 @@ def main(cfg: DictConfig) -> Trainer:
os.environ[
'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'
+ # Set CUDA lazy loading
+ # This can save a bit of memory if not all modules are needed
+ cuda_load_lazy: bool = cfg.pop('cuda_load_lazy', False)
+ if cuda_load_lazy:
+ os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
+
# Set seed first
seed: int = pop_config(cfg, 'seed', must_exist=True)
reproducibility.seed_all(seed)
@@ -462,7 +473,17 @@ def main(cfg: DictConfig) -> Trainer:
loggers = [
build_logger(str(name), logger_cfg)
for name, logger_cfg in logger_configs.items()
- ] if logger_configs else None
+ ] if logger_configs else []
+
+ mosaicml_logger = next(
+ (logger for logger in loggers if isinstance(logger, MosaicMLLogger)),
+ None)
+ if mosaicml_logger is None:
+ if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower(
+ ) == 'true' and os.environ.get(MOSAICML_ACCESS_TOKEN_ENV_VAR):
+ # Adds mosaicml logger to composer if the run was sent from Mosaic platform, access token is set, and mosaic logger wasn't previously added
+ mosaicml_logger = MosaicMLLogger()
+ loggers.append(mosaicml_logger)
# Profiling
profiler: Optional[Profiler] = None
@@ -510,6 +531,10 @@ def main(cfg: DictConfig) -> Trainer:
tokenizer,
device_train_batch_size,
)
+
+ if mosaicml_logger is not None:
+ mosaicml_logger.log_metrics({'data_validated': time.time()})
+
## Evaluation
print('Building eval loader...')
evaluators = []
@@ -616,6 +641,7 @@ def main(cfg: DictConfig) -> Trainer:
print('Logging config')
log_config(logged_cfg)
torch.cuda.empty_cache()
+ gc.collect()
# Eval first if requested
if eval_first and trainer.state.timestamp.batch.value == 0:
diff --git a/setup.py b/setup.py
index d0ecc66160..63aac9d752 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@
install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
- 'transformers>=4.33,<4.34',
+ 'transformers>=4.34.1,<4.35',
'mosaicml-streaming>=0.6,<0.7',
'torch>=1.13.1,<2.1.1',
'datasets>=2.14.5,<2.15',
@@ -114,9 +114,10 @@
extra_deps['all-cpu'] = set(
dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
- if key != 'gpu-flash2')
-extra_deps['all-flash2'] = set(
- dep for key, deps in extra_deps.items() for dep in deps if key != 'gpu')
+ if key not in {'gpu-flash2', 'all-cpu'})
+extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
+ for dep in deps
+ if key not in {'gpu', 'all', 'all-cpu'})
setup(
name=_PACKAGE_NAME,
diff --git a/tests/conftest.py b/tests/conftest.py
index b39ebd66a9..545dc7e38f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,12 +1,10 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-import gc
import os
from typing import List, Optional
import pytest
-import torch
from composer.utils import reproducibility
# Allowed options for pytest.mark.world_size()
@@ -18,6 +16,13 @@
# Enforce deterministic mode before any tests start.
reproducibility.configure_deterministic_mode()
+# Add the path of any pytest fixture files you want to make global
+pytest_plugins = [
+ 'tests.fixtures.autouse',
+ 'tests.fixtures.models',
+ 'tests.fixtures.data',
+]
+
def _add_option(parser: pytest.Parser,
name: str,
@@ -78,12 +83,3 @@ def pytest_collection_modifyitems(config: pytest.Config,
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
if exitstatus == 5:
session.exitstatus = 0 # Ignore no-test-ran errors
-
-
-@pytest.fixture(autouse=True)
-def clear_cuda_cache(request: pytest.FixtureRequest):
- """Clear memory between GPU tests."""
- marker = request.node.get_closest_marker('gpu')
- if marker is not None and torch.cuda.is_available():
- torch.cuda.empty_cache()
- gc.collect() # Only gc on GPU tests as it 2x slows down CPU tests
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
new file mode 100644
index 0000000000..f6c1f9f3ab
--- /dev/null
+++ b/tests/fixtures/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py
new file mode 100644
index 0000000000..c51ccfacb0
--- /dev/null
+++ b/tests/fixtures/autouse.py
@@ -0,0 +1,39 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import pytest
+import torch
+from composer.utils import dist, get_device, reproducibility
+
+
+@pytest.fixture(autouse=True)
+def initialize_dist(request: pytest.FixtureRequest):
+ """Initialize the default PyTorch distributed process group for tests."""
+ # should we just always initialize dist like in train.py?
+ _default = pytest.mark.world_size(1).mark
+ world_size = request.node.get_closest_marker('world_size', _default).args[0]
+ gpu = request.node.get_closest_marker('gpu')
+ if world_size > 1:
+ dist.initialize_dist(get_device('gpu' if gpu is not None else 'cpu'))
+
+
+@pytest.fixture(autouse=True)
+def clear_cuda_cache(request: pytest.FixtureRequest):
+ """Clear memory between GPU tests."""
+ marker = request.node.get_closest_marker('gpu')
+ if marker is not None and torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect() # Only gc on GPU tests as it 2x slows down CPU tests
+
+
+@pytest.fixture
+def random_seed() -> int:
+ return 17
+
+
+@pytest.fixture(autouse=True)
+def seed_all(random_seed: int):
+ """Sets the seed for reproducibility."""
+ reproducibility.seed_all(random_seed)
diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py
new file mode 100644
index 0000000000..39032146b6
--- /dev/null
+++ b/tests/fixtures/data.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+
+from composer.utils import dist
+from omegaconf import DictConfig
+from pytest import fixture
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerBase
+
+from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
+from tests.data_utils import make_tiny_ft_dataset
+
+
+@fixture
+def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:
+ """Creates a tiny dataset and returns the path."""
+ tiny_dataset_path = tmp_path / 'test-ift-data-small'
+ tiny_dataset_path.mkdir(exist_ok=True)
+ tiny_dataset_file = tiny_dataset_path / 'train.jsonl'
+ if dist.get_world_size() == 1 or dist.get_global_rank() == 0:
+ make_tiny_ft_dataset(path=str(tiny_dataset_file), size=dataset_size)
+ return tiny_dataset_path
+
+
+@fixture
+def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
+ mpt_tokenizer: PreTrainedTokenizerBase,
+ max_seq_len: int = 128,
+ device_batch_size: int = 1) -> DataLoader:
+ dataloader_cfg = DictConfig({
+ 'name': 'finetuning',
+ 'dataset': {
+ 'hf_name': str(tiny_ft_dataset_path),
+ 'split': 'train',
+ 'max_seq_len': max_seq_len,
+ 'decoder_only_format': True,
+ 'allow_pad_trimming': False,
+ 'packing_ratio': None,
+ 'shuffle': True,
+ },
+ 'drop_last': False,
+ 'num_workers': 4,
+ 'pin_memory': False,
+ 'prefetch_factor': 2,
+ 'persistent_workers': False,
+ 'timeout': 0
+ })
+
+ dataloader = build_finetuning_dataloader(
+ dataloader_cfg,
+ mpt_tokenizer,
+ device_batch_size,
+ ).dataloader
+
+ assert isinstance(dataloader, DataLoader)
+ return dataloader
diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py
new file mode 100644
index 0000000000..1b1ef86302
--- /dev/null
+++ b/tests/fixtures/models.py
@@ -0,0 +1,70 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable
+
+from omegaconf import DictConfig
+from pytest import fixture
+from transformers import PreTrainedTokenizerBase
+
+from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
+from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
+from llmfoundry.utils.builders import build_tokenizer
+
+
+def _build_model(config: DictConfig, tokenizer: PreTrainedTokenizerBase):
+ model = COMPOSER_MODEL_REGISTRY[config.name](config, tokenizer)
+ return model
+
+
+@fixture
+def mpt_tokenizer():
+ return build_tokenizer('EleutherAI/gpt-neox-20b', {})
+
+
+@fixture
+def build_tiny_mpt(
+ mpt_tokenizer: PreTrainedTokenizerBase
+) -> Callable[..., ComposerMPTCausalLM]:
+
+ def build(**kwargs: Any) -> ComposerMPTCausalLM:
+ config = DictConfig({
+ 'name': 'mpt_causal_lm',
+ 'd_model': 128,
+ 'n_heads': 4,
+ 'n_layers': 2,
+ 'expansion_ratio': 2,
+ })
+ config.update(kwargs)
+ model = _build_model(config, mpt_tokenizer)
+ assert isinstance(model, ComposerMPTCausalLM)
+ return model
+
+ return build
+
+
+@fixture
+def build_tiny_hf_mpt(
+ mpt_tokenizer: PreTrainedTokenizerBase
+) -> Callable[..., ComposerHFCausalLM]:
+
+ def build(**kwargs: Any) -> ComposerHFCausalLM:
+ config_overrides = {
+ 'd_model': 128,
+ 'n_heads': 4,
+ 'n_layers': 2,
+ 'expansion_ratio': 2,
+ }
+ config_overrides.update(kwargs)
+ config = DictConfig({
+ 'name': 'hf_causal_lm',
+ 'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
+ 'pretrained': False,
+ 'config_overrides': config_overrides,
+ })
+ model = _build_model(config, mpt_tokenizer)
+ assert isinstance(model, ComposerHFCausalLM)
+ return model
+
+ return build
diff --git a/tests/test_data_prep_scripts.py b/tests/test_data_prep_scripts.py
index 4c555ea9a2..4fe5ed7e64 100644
--- a/tests/test_data_prep_scripts.py
+++ b/tests/test_data_prep_scripts.py
@@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
import os
-import shutil
import sys
from argparse import Namespace
+from pathlib import Path
# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -13,17 +13,16 @@
from scripts.data_prep.convert_dataset_json import main as main_json
-def test_download_script_from_api():
+def test_download_script_from_api(tmp_path: Path):
# test calling it directly
- path = os.path.join(os.getcwd(), 'my-copy-c4-1')
- shutil.rmtree(path, ignore_errors=True)
+ path = os.path.join(tmp_path, 'my-copy-c4-1')
main_hf(
Namespace(
**{
'dataset': 'c4',
'data_subset': 'en',
'splits': ['val_xsmall'],
- 'out_root': './my-copy-c4-1',
+ 'out_root': path,
'compression': None,
'concat_tokens': None,
'bos_text': None,
@@ -32,18 +31,16 @@ def test_download_script_from_api():
'num_workers': None
}))
assert os.path.exists(path)
- shutil.rmtree(path, ignore_errors=False)
-def test_json_script_from_api():
+def test_json_script_from_api(tmp_path: Path):
# test calling it directly
- path = os.path.join(os.getcwd(), 'my-copy-arxiv-1')
- shutil.rmtree(path, ignore_errors=True)
+ path = os.path.join(tmp_path, 'my-copy-arxiv-1')
main_json(
Namespace(
**{
'path': 'scripts/data_prep/example_data/arxiv.jsonl',
- 'out_root': './my-copy-arxiv-1',
+ 'out_root': path,
'compression': None,
'split': 'train',
'concat_tokens': None,
@@ -53,4 +50,3 @@ def test_json_script_from_api():
'num_workers': None
}))
assert os.path.exists(path)
- shutil.rmtree(path, ignore_errors=False)
diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py
index e87d70223f..2441d0824a 100644
--- a/tests/test_flash_triton_torch.py
+++ b/tests/test_flash_triton_torch.py
@@ -3,7 +3,7 @@
import pytest
import torch
-from composer.utils import reproducibility
+
from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
from omegaconf import OmegaConf as om
from transformers.models.llama.modeling_llama import \
@@ -82,8 +82,6 @@ def test_attn_impl(attn_impl_0: str,
if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'):
pytest.xfail('flash attn does not support alibi')
- reproducibility.seed_all(7)
-
cfg = om.create({
'attn_impl': 'flash',
'd_model': 128,
@@ -253,8 +251,6 @@ def test_vs_mha(attn_impl: str, device: str = 'cuda'):
"""Compare diff attn_impl to torch.nn.MultiheadAttention."""
from llmfoundry.models.layers import attention
- reproducibility.seed_all(17)
-
cfg = om.create({
'attn_impl': attn_impl,
'd_model': 256,
@@ -352,8 +348,6 @@ def test_grouped_attention_heads(attn_impl: str,
"""Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads."""
from llmfoundry.models.layers import attention
- reproducibility.seed_all(17)
-
cfg = om.create({
'attn_impl': attn_impl,
'd_model': 256,
@@ -391,8 +385,6 @@ def test_grouped_query_invalid_heads(attn_impl: str, device: str = 'cuda'):
"""Check indivisble combinations of grouped_query_attention."""
from llmfoundry.models.layers import attention
- reproducibility.seed_all(17)
-
cfg = om.create({
'attn_impl': attn_impl,
'd_model': 256,
diff --git a/tests/test_hf_config.py b/tests/test_hf_config.py
index 5b3bb3d150..b47f267c55 100644
--- a/tests/test_hf_config.py
+++ b/tests/test_hf_config.py
@@ -9,7 +9,6 @@
import pytest
import torch
-from composer.utils import reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers import AutoModelForCausalLM
@@ -93,8 +92,6 @@ def test_hf_config_override(
with open(conf_path) as f:
test_cfg = om.load(f)
- reproducibility.seed_all(test_cfg.seed)
-
# Build Model
# For fast initialization, use `meta` device
print('Initializing model...')
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index 5bc3ed6d5d..d2f203d3a0 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -138,6 +138,49 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase,
tokenizer1.__dict__['init_kwargs'].pop('auto_map', None)
tokenizer2.__dict__['init_kwargs'].pop('auto_map', None)
+ # Additional special tokens do not match between original tokenizer and loaded tokenizer due to transformers
+ # constructor differences
+ additional_special_tokens_1 = {
+ t if isinstance(t, str) else t.content
+ for t in tokenizer1.__dict__.pop('_additional_special_tokens', [])
+ }
+ additional_special_tokens_2 = {
+ t if isinstance(t, str) else t.content
+ for t in tokenizer2.__dict__.pop('_additional_special_tokens', [])
+ }
+ # Also pop it out of init_kwargs
+ tokenizer1.__dict__['init_kwargs'].pop('additional_special_tokens', None)
+ tokenizer2.__dict__['init_kwargs'].pop('additional_special_tokens', None)
+ tokenizer1.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
+ tokenizer2.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
+ # If the additional special tokens are the same (or a subset of each other), or if one of them is empty, then we are good
+ assert additional_special_tokens_1.issubset(
+ additional_special_tokens_2) or additional_special_tokens_2.issubset(
+ additional_special_tokens_1)
+
+ # The special token attributes may be strings or they may be AddedToken objects, so we just check string values
+ # First check that they have the same attrs
+ assert tokenizer1.SPECIAL_TOKENS_ATTRIBUTES == tokenizer2.SPECIAL_TOKENS_ATTRIBUTES
+ # Then check that the values are the same
+ for special_token_attr in tokenizer1.SPECIAL_TOKENS_ATTRIBUTES:
+ # Skip additional_special_tokens because we already checked it above
+ if special_token_attr == 'additional_special_tokens':
+ continue
+
+ # The init_kwargs can change between the original tokenizer and the loaded tokenizer,
+ # so we just pop them
+ tokenizer1.__dict__['init_kwargs'].pop(special_token_attr, None)
+ tokenizer2.__dict__['init_kwargs'].pop(special_token_attr, None)
+
+ attr1 = tokenizer1.__dict__.pop('_' + special_token_attr, None)
+ attr2 = tokenizer2.__dict__.pop('_' + special_token_attr, None)
+ if attr1 is None and attr2 is None:
+ continue
+
+ attr_value1 = attr1 if isinstance(attr1, str) else attr1.content
+ attr_value2 = attr2 if isinstance(attr2, str) else attr2.content
+ assert attr_value1 == attr_value2
+
assert tokenizer1.__dict__ == tokenizer2.__dict__
@@ -174,6 +217,10 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
def delete_transformers_cache():
+ # Only delete the files on local rank 0, otherwise race conditions are created
+ if not dist.get_local_rank() == 0:
+ return
+
hf_cache_home = os.path.expanduser(
os.getenv(
'HF_HOME',
@@ -204,25 +251,30 @@ def test_callback_inits_with_defaults():
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
+@pytest.mark.parametrize(
+ 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
+ [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_state_dict_type: Optional[str],
- log_to_mlflow: bool):
+ log_to_mlflow: bool,
+ hf_save_interval: str,
+ save_interval: str, max_duration: str,
+ expected_hf_checkpoints: int,
+ expected_normal_checkpoints: int):
delete_transformers_cache()
dist.initialize_dist(get_device('gpu'))
max_seq_len = 16
- save_interval_batches = 2
- huggingface_save_interval_batches = 3
device_batch_size = 1
dataset_size = 14
- max_duration_batches = 7
precision_str = 'bfloat16'
precision = torch.bfloat16
+ batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))
checkpointer_callback = HuggingFaceCheckpointer(
save_folder=os.path.join(tmp_path, 'checkpoints'),
- save_interval=f'{huggingface_save_interval_batches}ba',
+ save_interval=hf_save_interval,
precision=precision_str,
mlflow_registered_model_name='dummy-registered-name'
if log_to_mlflow else None,
@@ -358,8 +410,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
- save_interval=f'{save_interval_batches}ba',
- max_duration=f'{max_duration_batches}ba',
+ save_interval=save_interval,
+ max_duration=max_duration,
callbacks=[checkpointer_callback],
loggers=[mlflow_logger_mock] if log_to_mlflow else [],
optimizers=optimizer,
@@ -395,15 +447,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
name for name in os.listdir(
os.path.join(tmp_path, 'checkpoints', 'huggingface'))
]
- assert len(normal_checkpoints) == math.ceil(max_duration_batches /
- save_interval_batches)
- assert len(huggingface_checkpoints) == math.ceil(
- max_duration_batches / huggingface_save_interval_batches)
+ assert len(normal_checkpoints) == expected_normal_checkpoints
+ assert len(huggingface_checkpoints) == expected_hf_checkpoints
# Load the last huggingface checkpoint
loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
os.path.join(tmp_path, 'checkpoints', 'huggingface',
- f'ba{max_duration_batches}'),
+ f'ba{batches_per_epoch}'),
trust_remote_code=True,
)
@@ -424,7 +474,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
os.path.join(tmp_path, 'checkpoints', 'huggingface',
- f'ba{max_duration_batches}'),
+ f'ba{batches_per_epoch}'),
trust_remote_code=True,
)
@@ -434,6 +484,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
loaded_model)
check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)
+ dist.barrier()
delete_transformers_cache()
diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py
index cc357141ba..ea133c64fa 100644
--- a/tests/test_hf_mpt_gen.py
+++ b/tests/test_hf_mpt_gen.py
@@ -1,167 +1,51 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-from pathlib import Path
-from typing import Any, Dict
-from unittest.mock import Mock
+from typing import Callable
import pytest
-from composer.callbacks import Generate as ComposerGenerate
from composer.core.precision import get_precision_context
-from composer.trainer import Trainer
-from composer.utils import get_device, reproducibility
-from omegaconf import DictConfig
-from omegaconf import OmegaConf as om
+from composer.utils import get_device
+from transformers import PreTrainedTokenizerBase
-from llmfoundry import COMPOSER_MODEL_REGISTRY
-from llmfoundry.data.finetuning import build_finetuning_dataloader
-from llmfoundry.utils import build_tokenizer
-from tests.data_utils import make_tiny_ft_dataset
+from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
@pytest.mark.gpu
@pytest.mark.parametrize('device', ['cpu', 'gpu'])
@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
-def test_init_hfhub_mpt(device: str, attn_impl: str):
+def test_init_hfhub_mpt(
+ device: str,
+ attn_impl: str,
+ build_tiny_hf_mpt: Callable[..., ComposerHFCausalLM],
+ mpt_tokenizer: PreTrainedTokenizerBase,
+):
if device == 'cpu' and attn_impl == 'triton':
pytest.skip(f'{attn_impl=} not implemented for {device=}.')
composer_device = get_device(device)
- with open('scripts/train/yamls/pretrain/testing.yaml') as f:
- test_cfg = om.load(f)
-
- assert isinstance(test_cfg, DictConfig)
- reproducibility.seed_all(test_cfg.get('seed', 42))
-
- attn_uses_sequence_id = True if test_cfg.get('eos_token_id',
- None) is not None else False
- test_cfg.model = DictConfig({
- 'name': 'hf_causal_lm',
- 'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
- 'pretrained': False,
- 'config_overrides': {
- 'd_model': 128,
- 'n_heads': 4,
- 'n_layers': 2,
- 'expansion_ratio': 2,
- 'attn_config': {
- 'attn_impl': attn_impl,
- 'attn_uses_sequence_id': attn_uses_sequence_id,
- },
- },
+ model = build_tiny_hf_mpt(attn_config={
+ 'attn_impl': attn_impl,
+ 'attn_uses_sequence_id': False,
})
-
- # build tokenizer
- tokenizer_cfg: Dict[str,
- Any] = om.to_container(test_cfg.tokenizer,
- resolve=True) # type: ignore
- tokenizer_name = tokenizer_cfg['name']
- tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
- tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
-
- # build model
- model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
- tokenizer)
- test_cfg.n_params = sum(p.numel() for p in model.parameters())
+ model = composer_device.module_to_device(model)
model.eval()
- model = composer_device.module_to_device(model)
with get_precision_context('amp_bf16' if composer_device.name ==
'gpu' else 'fp32'):
_ = model.generate(
composer_device.tensor_to_device(
- tokenizer('hello', return_tensors='pt')['input_ids']),
+ mpt_tokenizer('hello', return_tensors='pt')['input_ids']),
max_new_tokens=10,
)
-def test_init_hfhub_mpt_cpu():
- test_init_hfhub_mpt(device='cpu', attn_impl='torch')
-
-
-@pytest.mark.gpu
-def test_mpt_generate_callback(tmpdir: Path):
- composer_device = get_device('gpu')
- reproducibility.seed_all(42)
- max_seq_len = 128
-
- # testing dataset and dataloader
- dataset_size = 5
-
- tiny_dataset_path = tmpdir / 'test-ift-data-small'
- tiny_dataset_path.mkdir()
- tiny_dataset_file = tiny_dataset_path / 'train.jsonl'
- make_tiny_ft_dataset(path=str(tiny_dataset_file), size=dataset_size)
-
- dataloader_cfg = DictConfig({
- 'name': 'finetuning',
- 'dataset': {
- 'hf_name': str(tiny_dataset_path),
- 'split': 'train',
- 'max_seq_len': max_seq_len,
- 'decoder_only_format': True,
- 'allow_pad_trimming': False,
- 'packing_ratio': None,
- 'shuffle': True,
- },
- 'drop_last': False,
- 'num_workers': 4,
- 'pin_memory': False,
- 'prefetch_factor': 2,
- 'persistent_workers': False,
- 'timeout': 0
- })
-
- # build tokenizer
- tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})
-
- # build mpt model
- model_config = DictConfig({
- 'name': 'mpt_causal_lm',
- 'config_overrides': {
- 'd_model': 128,
- 'n_heads': 4,
- 'n_layers': 2,
- 'expansion_ratio': 2,
- },
- })
- model = COMPOSER_MODEL_REGISTRY[model_config.name](model_config, tokenizer)
- model = composer_device.module_to_device(model)
-
- # generate callback
- prompts = [
- 'The best banana bread recipe is',
- '2+2=',
- 'how much wood could a woodchuck chuck',
- ]
- gen_interval = 1
- generate = ComposerGenerate(
- prompts,
- interval=f'{gen_interval}ba',
- max_new_tokens=5,
- batch_size=len(prompts),
- use_cache=True,
- )
- generate.generate = Mock(wraps=generate.generate, autospec=True)
-
- # build trainer
- device_batch_size = 1
- train_dataloader = build_finetuning_dataloader(
- dataloader_cfg,
- tokenizer,
- device_batch_size,
- )
-
- trainer = Trainer(
- model=model,
- train_dataloader=train_dataloader,
- device=composer_device,
- max_duration=f'{gen_interval}ba',
- callbacks=[generate],
- )
- trainer.logger.log_table = Mock()
- trainer.fit()
-
- generate.generate.assert_called_once()
- trainer.logger.log_table.assert_called_once()
+def test_init_hfhub_mpt_cpu(
+ build_tiny_hf_mpt: Callable[..., ComposerHFCausalLM],
+ mpt_tokenizer: PreTrainedTokenizerBase,
+):
+ test_init_hfhub_mpt(device='cpu',
+ attn_impl='torch',
+ build_tiny_hf_mpt=build_tiny_hf_mpt,
+ mpt_tokenizer=mpt_tokenizer)
diff --git a/tests/test_hf_v_mpt.py b/tests/test_hf_v_mpt.py
index 82e2d05550..46172faf35 100644
--- a/tests/test_hf_v_mpt.py
+++ b/tests/test_hf_v_mpt.py
@@ -5,7 +5,6 @@
import pytest
import torch
-from composer.utils import reproducibility
from omegaconf import OmegaConf as om
from llmfoundry import COMPOSER_MODEL_REGISTRY
@@ -52,10 +51,6 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
batch_size = 2 # set batch size
device = 'cuda' # set decive
- # ensure reproducibility
- seed = 17
- reproducibility.seed_all(seed) # set seed
-
# get hf gpt2 cfg
hf_cfg = om.create({
'model': {
@@ -154,11 +149,9 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
# UTIL: can be used to verify that models are not the same at init
with torch.autocast(device_type='cuda', dtype=torch.float16):
- torch.manual_seed(seed)
hf_model_fwd = hf_model(batch)['logits']
if kpm is not None:
hf_model_fwd *= kpm
- torch.manual_seed(seed)
model_fwd = model(batch).logits
if kpm is not None:
model_fwd *= kpm
@@ -208,11 +201,9 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
model.load_state_dict(_hf_model_statedict)
with torch.autocast(device_type=device, dtype=torch.float16):
- torch.manual_seed(seed)
hf_model_fwd = hf_model(batch)['logits']
if kpm is not None:
hf_model_fwd *= kpm
- torch.manual_seed(seed)
model_fwd = model(batch).logits
if kpm is not None:
model_fwd *= kpm
diff --git a/tests/test_huggingface_flash.py b/tests/test_huggingface_flash.py
new file mode 100644
index 0000000000..a71217ea1f
--- /dev/null
+++ b/tests/test_huggingface_flash.py
@@ -0,0 +1,195 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import contextlib
+import os
+from unittest.mock import patch
+
+import pytest
+import torch
+import transformers
+from composer.core.precision import get_precision_context
+from composer.utils import reproducibility
+from omegaconf import OmegaConf as om
+
+from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.models.hf.hf_fsdp import rgetattr
+from llmfoundry.models.layers.attention import (is_flash_v1_installed,
+ is_flash_v2_installed)
+from llmfoundry.utils.builders import build_tokenizer
+
+# Before importing any transformers models, we need to disable transformers flash attention if
+# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
+# gated import otherwise.
+if is_flash_v1_installed():
+ transformers.utils.is_flash_attn_available = lambda: False
+
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+from llmfoundry.models.layers.llama_attention_monkeypatch import (
+ llama_attention_patch_torch, llama_attention_patch_triton)
+
+
+@pytest.mark.parametrize('patch_fn_name', ['torch', 'triton'])
+@pytest.mark.parametrize('explicit_mask', [True, False])
+@pytest.mark.parametrize(
+ 'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
+@pytest.mark.gpu
+def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
+ model_name: str):
+ if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
+ pytest.skip(
+ 'The CI cluster does not have access to the Llama models, so skip this test.'
+ )
+
+ device = 'cuda:0'
+ sequence_length = 4096
+ model_dim = 4096 if '7b' in model_name else 8192
+ batch_size = 2
+ if patch_fn_name == 'torch':
+ patch_fn = llama_attention_patch_torch
+ dtype = torch.float32
+ atol = 0.0
+ rtol = 0.0
+ elif patch_fn_name == 'triton':
+ # the huggingface implementation of llama performs the softmax in fp32
+ # this can result in fairly large differences for the triton implementation
+ # but the torch implementation produces the exact same output so we can confirm
+ # the implementation is correct
+ patch_fn = llama_attention_patch_triton
+ dtype = torch.bfloat16
+ atol = 1e-2
+ rtol = 1e-2
+ else:
+ raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
+
+ llama_config = transformers.AutoConfig.from_pretrained(model_name,
+ use_auth_token=True)
+
+ reproducibility.seed_all(42)
+ attention = LlamaAttention(config=llama_config,)
+ attention.to(dtype=dtype, device=device)
+
+ rng = torch.Generator(device=device).manual_seed(42)
+ hidden_states = torch.randn(batch_size,
+ sequence_length,
+ model_dim,
+ generator=rng,
+ dtype=dtype,
+ device=device)
+ causal_mask = torch.full((sequence_length, sequence_length),
+ torch.finfo(torch.float32).min,
+ device=device)
+ causal_mask = causal_mask.triu(diagonal=1)
+ causal_mask = causal_mask[None,
+ None, :, :].expand(batch_size, 1, sequence_length,
+ sequence_length)
+ attn_output, _, _ = attention(
+ hidden_states=hidden_states,
+ attention_mask=causal_mask if explicit_mask else None,
+ position_ids=None,
+ past_key_value=None,
+ use_cache=False,
+ )
+
+ reproducibility.seed_all(42)
+ with patch.object(LlamaAttention, 'forward', new=patch_fn):
+ attention = LlamaAttention(config=llama_config,)
+ attention.to(dtype=dtype, device=device)
+ new_output, _, _ = attention(
+ hidden_states=hidden_states,
+ attention_mask=causal_mask if explicit_mask else None,
+ position_ids=None,
+ past_key_value=None,
+ use_cache=False,
+ )
+
+ assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
+@pytest.mark.parametrize('use_flash_attention_2', [True, False])
+def test_flash2(model_name: str, use_flash_attention_2: bool):
+ if model_name == 'llama2':
+ if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
+ pytest.skip(
+ 'The CI cluster does not have access to the Llama models, so skip this test.'
+ )
+ model_cfg = {
+ 'name': 'hf_causal_lm',
+ 'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
+ 'config_overrides': {
+ 'num_hidden_layers': 2,
+ 'intermediate_size': 64,
+ },
+ 'use_auth_token': True,
+ 'pretrained': False,
+ 'init_device': 'cpu',
+ }
+
+ tokenizer_name = 'meta-llama/Llama-2-7b-hf'
+ from transformers.models.llama.modeling_llama import (
+ LlamaAttention, LlamaFlashAttention2)
+ flash_attn_class = LlamaFlashAttention2 if use_flash_attention_2 else LlamaAttention
+ attention_layers_attr = 'model.model.layers'
+ attention_attr = 'self_attn'
+ elif model_name == 'mistral':
+ model_cfg = {
+ 'name': 'hf_causal_lm',
+ 'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1',
+ 'config_overrides': {
+ 'num_hidden_layers': 2,
+ 'intermediate_size': 64,
+ },
+ 'pretrained': False,
+ 'init_device': 'cpu',
+ }
+
+ tokenizer_name = 'mistralai/Mistral-7B-v0.1'
+ from transformers.models.mistral.modeling_mistral import (
+ MistralAttention, MistralFlashAttention2)
+ flash_attn_class = MistralFlashAttention2 if use_flash_attention_2 else MistralAttention
+ attention_layers_attr = 'model.model.layers'
+ attention_attr = 'self_attn'
+ else:
+ raise ValueError(f'Unknown model: {model_name}')
+
+ if use_flash_attention_2:
+ model_cfg['use_flash_attention_2'] = True
+
+ model_cfg = om.create(model_cfg)
+
+ tokenizer = build_tokenizer(
+ tokenizer_name=tokenizer_name,
+ tokenizer_kwargs={'model_max_length': 10},
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+
+ error_context = pytest.raises(
+ ValueError, match='use_flash_attention_2 is set to True'
+ ) if not is_flash_v2_installed(
+ ) and use_flash_attention_2 else contextlib.nullcontext()
+
+ with error_context:
+ model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)
+
+ # check that it actually used flash attention 2
+ assert model.model.config._flash_attn_2_enabled if use_flash_attention_2 else not model.model.config._flash_attn_2_enabled
+ attention_layer = rgetattr(
+ rgetattr(model, attention_layers_attr)[0], attention_attr)
+ assert isinstance(attention_layer, flash_attn_class)
+
+ tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
+ return_tensors='pt',
+ padding=True)
+ tokenized_input['labels'] = tokenized_input['input_ids'].clone()
+
+ tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
+ model.to('cuda')
+
+ with get_precision_context('amp_bf16'):
+ # We're just testing that flash attention 2 runs okay
+ outputs = model(tokenized_input)
+ loss = outputs.loss
+ loss.backward()
diff --git a/tests/test_init_fn.py b/tests/test_init_fn.py
index b054bac186..6be2c5ca42 100644
--- a/tests/test_init_fn.py
+++ b/tests/test_init_fn.py
@@ -8,7 +8,6 @@
import pytest
import torch
-from composer.utils import reproducibility
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from torch import nn
@@ -35,8 +34,6 @@ def forward(self, x: torch.Tensor):
@pytest.mark.parametrize('is_residual', [True, False])
def test_div_is_residual(is_residual: bool):
- reproducibility.seed_all(7)
-
in_features, out_features = 8, 32
cfg = om.create({
'in_features': in_features,
@@ -64,8 +61,6 @@ def test_div_is_residual(is_residual: bool):
@pytest.mark.parametrize('fused', [True, False])
def test_fused_init_helper(fused: bool):
- reproducibility.seed_all(7)
-
in_features, out_features = 8, 32
cfg = om.create({
'in_features': in_features,
@@ -133,8 +128,6 @@ def max_fill_init_(weight: torch.Tensor):
('emb_init_uniform_lim', [1, 1])
])
def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]):
- reproducibility.seed_all(7)
-
cfg: Dict[str, Union[int, List[int]]] = {
'vocab_size': 64,
'in_features': 16,
diff --git a/tests/test_llama_patch.py b/tests/test_llama_patch.py
deleted file mode 100644
index b1cd3711e0..0000000000
--- a/tests/test_llama_patch.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-import os
-
-import pytest
-import torch
-import transformers
-from composer.utils import reproducibility
-from transformers.models.llama.modeling_llama import LlamaAttention
-
-from llmfoundry.models.layers.llama_attention_monkeypatch import (
- llama_attention_patch_torch, llama_attention_patch_triton)
-
-
-@pytest.mark.parametrize('patch_fn_name', ['torch', 'triton'])
-@pytest.mark.parametrize('explicit_mask', [True, False])
-@pytest.mark.parametrize(
- 'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
-@pytest.mark.gpu
-def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
- model_name: str):
- if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
- pytest.skip(
- 'The CI cluster does not have access to the Llama models, so skip this test.'
- )
-
- original_forward = LlamaAttention.forward
-
- device = 'cuda:0'
- sequence_length = 4096
- model_dim = 4096 if '7b' in model_name else 8192
- batch_size = 2
- if patch_fn_name == 'torch':
- patch_fn = llama_attention_patch_torch
- dtype = torch.float32
- atol = 0.0
- rtol = 0.0
- elif patch_fn_name == 'triton':
- # the huggingface implementation of llama performs the softmax in fp32
- # this can result in fairly large differences for the triton implementation
- # but the torch implementation produces the exact same output so we can confirm
- # the implementation is correct
- patch_fn = llama_attention_patch_triton
- dtype = torch.bfloat16
- atol = 1e-2
- rtol = 1e-2
- else:
- raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
-
- llama_config = transformers.AutoConfig.from_pretrained(model_name,
- use_auth_token=True)
-
- reproducibility.seed_all(42)
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
-
- rng = torch.Generator(device=device).manual_seed(42)
- hidden_states = torch.randn(batch_size,
- sequence_length,
- model_dim,
- generator=rng,
- dtype=dtype,
- device=device)
- causal_mask = torch.full((sequence_length, sequence_length),
- torch.finfo(torch.float32).min,
- device=device)
- causal_mask = causal_mask.triu(diagonal=1)
- causal_mask = causal_mask[None,
- None, :, :].expand(batch_size, 1, sequence_length,
- sequence_length)
- attn_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=None,
- past_key_value=None,
- use_cache=False,
- )
-
- reproducibility.seed_all(42)
- LlamaAttention.forward = patch_fn
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
- new_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=None,
- past_key_value=None,
- use_cache=False,
- )
-
- # Reset the forward function so patches don't persist
- LlamaAttention.forward = original_forward
-
- assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
diff --git a/tests/test_model.py b/tests/test_model.py
index 69aa05a362..ef76396da2 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -16,7 +16,7 @@
from composer.core.precision import Precision, get_precision_context
from composer.optim import DecoupledAdamW
from composer.trainer.dist_strategy import prepare_fsdp_module
-from composer.utils import dist, get_device, reproducibility
+from composer.utils import dist, get_device
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
@@ -56,8 +56,6 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
message='Torchmetrics v0.9 introduced a new argument class property')
test_cfg = get_config(conf_path=conf_path)
- reproducibility.seed_all(test_cfg.seed)
-
# Read FSDP Config as a dict
fsdp_config = test_cfg.get('fsdp_config', None)
fsdp_config = om.to_container(fsdp_config,
@@ -316,7 +314,6 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
pytest.skip(
'This test requires CUDA to be available in order to run with bfloat16 precision.'
)
- reproducibility.seed_all(1111)
conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
with open(conf_path) as f:
@@ -394,8 +391,6 @@ def test_loss_fn():
'init_std': 0.02,
}
- reproducibility.seed_all(test_cfg.get('global_seed', 42))
-
tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
tokenizer = build_tokenizer(test_cfg.tokenizer.name,
tokenizer_cfg.get('kwargs', {}))
@@ -578,7 +573,6 @@ def test_forward_with_padding(attention_impl: str, device: str,
pytest.skip(
f'dail implementation of rope is only implemented for gpus.')
- reproducibility.seed_all(1234)
composer_device = get_device(device)
hf_config = MPTConfig(
@@ -815,7 +809,6 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
pytest.skip(
f'dail implementation of rope is only implemented for gpus.')
- reproducibility.seed_all(1234)
composer_device = get_device(device)
hf_config = MPTConfig(
@@ -875,14 +868,12 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
use_cache=False)
assert batched_generation.shape == (2, 6 + 5)
- reproducibility.seed_all(1234)
generation_with_left_padding = mpt.generate(
input_ids=left_padding_input_ids,
attention_mask=left_padding_attention_mask,
max_new_tokens=5,
use_cache=False)
assert generation_with_left_padding.shape == (2, 6 + 5)
- reproducibility.seed_all(1234)
generation_with_no_padding = mpt.generate(
input_ids=no_padding_input_ids,
attention_mask=no_padding_attention_mask,
@@ -1224,14 +1215,12 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
'init_std': 0.02,
},
)
- reproducibility.seed_all(1234)
mpt = MPTForCausalLM(hf_config)
mpt = composer_device.module_to_device(mpt)
mpt.eval()
with get_precision_context('amp_bf16' if composer_device.name ==
'gpu' else 'fp32'):
- reproducibility.seed_all(1234)
first_input_ids = torch.tensor([[11274, 16390, 11]])
first_input_ids = composer_device.tensor_to_device(first_input_ids)
first_attention_mask = torch.tensor([[1, 1, 1]]).bool()
@@ -1257,7 +1246,6 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
assert all(past_key_value[1].shape == (1, 3, 128)
for past_key_value in first_output.past_key_values)
- reproducibility.seed_all(1234)
second_input_ids = torch.tensor([[11274, 16390, 11, 11274]])
second_input_ids = composer_device.tensor_to_device(second_input_ids)
second_attention_mask = torch.tensor([[1, 1, 1, 1]]).bool()
@@ -1287,7 +1275,6 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
assert all(past_key_value[1].shape == (1, 4, 128)
for past_key_value in second_output.past_key_values)
- reproducibility.seed_all(1234)
# pass through the first four tokens without the key-value cache
full_output = mpt(second_input_ids,
attention_mask=second_attention_mask)
@@ -1551,7 +1538,6 @@ def test_model_to(attention_impl: str, pos_emb_config: dict):
'init_std': 0.02,
},
)
- reproducibility.seed_all(1234)
mpt = MPTForCausalLM(hf_config)
mpt = mpt.bfloat16()
mpt = mpt.to('cuda')
@@ -1700,14 +1686,12 @@ def test_forward_with_output_attentions_and_output_hidden_states(
'init_std': 0.02,
},
)
- reproducibility.seed_all(1234)
mpt = MPTForCausalLM(hf_config)
mpt = composer_device.module_to_device(mpt)
mpt.eval()
with get_precision_context('amp_bf16' if composer_device.name ==
'gpu' else 'fp32'):
- reproducibility.seed_all(1234)
input_ids = torch.tensor([[11274, 16390, 11]])
input_ids = composer_device.tensor_to_device(input_ids)
attention_mask = torch.tensor([[1, 1, 1]]).bool()
diff --git a/tests/test_mpt_gen.py b/tests/test_mpt_gen.py
index 06ddccd479..c52b765480 100644
--- a/tests/test_mpt_gen.py
+++ b/tests/test_mpt_gen.py
@@ -1,19 +1,21 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-from typing import List, Optional, Tuple
-from unittest.mock import patch
+from typing import Callable, List, Optional, Tuple
+from unittest.mock import Mock, patch
import pytest
import torch
+from composer import Trainer
+from composer.callbacks import Generate as ComposerGenerate
from composer.core.precision import get_precision_context
-from composer.utils import dist, get_device, reproducibility
-from omegaconf import DictConfig
+from composer.utils import dist, get_device
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerBase
-from llmfoundry import COMPOSER_MODEL_REGISTRY
-from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM
-from llmfoundry.utils import build_tokenizer
+from llmfoundry.models.mpt.modeling_mpt import (ComposerMPTCausalLM,
+ MPTForCausalLM)
EOS_TOKEN_ID = 0
@@ -55,44 +57,72 @@ def forward(
@pytest.mark.parametrize('use_alibi', [True, False])
@patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM',
new=MockMPTForCausalLM)
-def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool):
+def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,
+ build_tiny_mpt: Callable[...,
+ ComposerMPTCausalLM],
+ mpt_tokenizer: PreTrainedTokenizerBase):
"""Tests mpt generation with mutiple gpus.
and generations of different lengths.
"""
- composer_device = get_device('gpu')
- dist.initialize_dist(composer_device)
- reproducibility.seed_all(42)
-
- model_config = DictConfig({
- 'name': 'mpt_causal_lm',
- 'd_model': 128,
- 'n_heads': 4,
- 'n_layers': 2,
- 'expansion_ratio': 2,
- 'no_bias': False,
- 'use_cache': True,
- 'attn_config': {
- 'attn_impl': attn_impl,
- 'attn_uses_sequence_id': False,
- 'alibi': use_alibi
- },
- })
-
- # build tokenizer
- tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})
-
- # build model
- model = COMPOSER_MODEL_REGISTRY[model_config.name](model_config, tokenizer)
- model = composer_device.module_to_device(model)
+ device = get_device('gpu')
+
+ model = build_tiny_mpt(attn_config={
+ 'attn_impl': attn_impl,
+ 'attn_uses_sequence_id': False,
+ 'alibi': use_alibi
+ },)
+ model = device.module_to_device(model)
+
model.eval()
model.model = FSDP(model.model)
with get_precision_context('amp_bf16'):
- _ = model.generate(composer_device.tensor_to_device(
- tokenizer('hello', return_tensors='pt')['input_ids']),
+ _ = model.generate(device.tensor_to_device(
+ mpt_tokenizer('hello', return_tensors='pt')['input_ids']),
max_new_tokens=3,
eos_token_id=EOS_TOKEN_ID,
use_cache=True,
synced_gpus=True)
+
+
+@pytest.mark.gpu
+def test_mpt_generate_callback(build_tiny_mpt: Callable[...,
+ ComposerMPTCausalLM],
+ tiny_ft_dataloader: DataLoader):
+ device = get_device('gpu')
+
+ # build mpt model
+ model = build_tiny_mpt()
+ model = device.module_to_device(model)
+
+ # generate callback
+ prompts = [
+ 'The best banana bread recipe is',
+ '2+2=',
+ 'how much wood could a woodchuck chuck',
+ ]
+ gen_interval = 1
+ generate = ComposerGenerate(
+ prompts,
+ interval=f'{gen_interval}ba',
+ max_new_tokens=5,
+ batch_size=len(prompts),
+ use_cache=True,
+ )
+ generate.generate = Mock(wraps=generate.generate, autospec=True)
+
+ # build trainer
+ trainer = Trainer(
+ model=model,
+ train_dataloader=tiny_ft_dataloader,
+ device=device,
+ max_duration=f'{gen_interval}ba',
+ callbacks=[generate],
+ )
+ trainer.logger.log_table = Mock()
+ trainer.fit()
+
+ generate.generate.assert_called_once()
+ trainer.logger.log_table.assert_called_once()
diff --git a/tests/test_onnx.py b/tests/test_onnx.py
index 4ccb8e4112..d0e01746eb 100644
--- a/tests/test_onnx.py
+++ b/tests/test_onnx.py
@@ -4,7 +4,6 @@
import pathlib
import torch
-from composer.utils import reproducibility
from transformers import AutoModelForCausalLM
from llmfoundry import MPTConfig, MPTForCausalLM
@@ -27,7 +26,6 @@ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int):
def test_onnx_export(tmp_path: pathlib.Path):
- reproducibility.seed_all(42)
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM)
diff --git a/tests/test_tiktoken.py b/tests/test_tiktoken.py
index a255a5ffa7..85ff18100b 100644
--- a/tests/test_tiktoken.py
+++ b/tests/test_tiktoken.py
@@ -45,14 +45,19 @@
def get_tokenizers_for_testing(
- model_name: Optional[str], encoding_name: Optional[str],
- tmp_path: pathlib.Path
+ model_name: Optional[str],
+ encoding_name: Optional[str],
+ tmp_path: pathlib.Path,
+ add_bos_token: bool = False,
+ add_eos_token: bool = False
) -> Tuple[TiktokenTokenizerWrapper, TiktokenTokenizerWrapper, 'Encoding']:
tiktoken = pytest.importorskip('tiktoken')
# Construction
wrapped_tokenizer = TiktokenTokenizerWrapper(model_name=model_name,
- encoding_name=encoding_name)
+ encoding_name=encoding_name,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token)
if model_name is not None:
original_tokenizer = tiktoken.encoding_for_model(model_name)
else:
@@ -201,3 +206,29 @@ def test_tiktoken_save_from_pretrained(model_name: Optional[str],
model_name, encoding_name, tmp_path)
check_hf_tokenizer_equivalence(wrapped_tokenizer,
reloaded_wrapped_tokenizer)
+
+
+@pytest.mark.parametrize('model_name,encoding_name',
+ MODEL_ENCODING_NAME_PARAMETRIZATION)
+def test_tiktoken_encode_plus(model_name: Optional[str],
+ encoding_name: Optional[str],
+ tmp_path: pathlib.Path):
+ # Testing encode_plus which optionally wrap encodes with bos and eos tokens
+ wrapped_tokenizer, _, _ = get_tokenizers_for_testing(model_name,
+ encoding_name,
+ tmp_path,
+ add_bos_token=True,
+ add_eos_token=True)
+
+ for test_string in TEST_STRINGS:
+ encoded_outputs = wrapped_tokenizer.encode_plus(
+ test_string,
+ add_special_tokens=True,
+ return_special_tokens_mask=True)
+ encoded_input_ids = encoded_outputs.input_ids
+ assert encoded_input_ids[0] == wrapped_tokenizer.bos_token_id
+ assert encoded_input_ids[-1] == wrapped_tokenizer.eos_token_id
+
+ encoded_special_mask = encoded_outputs.special_tokens_mask
+ assert encoded_special_mask[0] == 1
+ assert encoded_special_mask[-1] == 1