diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index cf6b85b193..144e3f1ad3 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -32,14 +32,6 @@ jobs:
PYPI_PACKAGE_NAME="llm-foundry-test-$(date +%Y%m%d%H%M%S)"
fi
- # Remove the peft, xentropy-cuda-lib and triton-pre-mlir dependencies as PyPI does not
- # support direct installs. The error message for importing PEFT, FusedCrossEntropy,
- # and flash_attn_triton gives instructions on how to install if a user tries to use it
- # without this dependency.
- sed '/xentropy-cuda-lib@git+https:\/\/github.com\/HazyResearch\/flash-attention.git@.*/d' -i setup.py
- sed '/triton-pre-mlir@git+https:\/\/github.com\/vchiley\/triton.git@.*/d' -i setup.py
- sed '/peft@git+https:\/\/github.com\/huggingface\/peft.git.*/d' -i setup.py
-
python -m pip install --upgrade build twine
python -m build
twine check --strict dist/*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 62bc853fb5..3551482244 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,5 @@
default_language_version:
python: python3
-exclude: llmfoundry/models/layers/flash_attn_triton.py
repos:
- repo: https://github.com/google/yapf
rev: v0.32.0
diff --git a/README.md b/README.md
index aaf4b70c35..7ce58c772f 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,20 @@ You'll find in this repo:
* `mcli/` - launch any of these workloads using [MCLI](https://docs.mosaicml.com/projects/mcli/en/latest/) and the [MosaicML platform](https://www.mosaicml.com/platform)
* `TUTORIAL.md` - a deeper dive into the repo, example workflows, and FAQs
+# DBRX
+
+DBRX is a state-of-the-art open source LLM trained by Databricks Mosaic team. It uses the Mixture-of-Experts (MoE) architecture and was trained with optimized versions of [Composer](https://github.com/mosaicml/composer), LLM Foundry, and [MegaBlocks](https://github.com/databricks/megablocks). The model has 132B total parameters and 36B active parameters. We have released two DBRX models:
+
+
+| Model | Context Length | Download |
+| ------------------ | -------------- | -------------------------------------------------- |
+| DBRX Base | 32768 | https://huggingface.co/databricks/dbrx-base |
+| DBRX Instruct | 32768 | https://huggingface.co/databricks/dbrx-instruct |
+
+Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model).
+
+For more information about the DBRX models, see https://github.com/databricks/dbrx.
+
# MPT
Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:
@@ -184,7 +198,6 @@ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.o
**Lastly**, install the ROCm enabled flash attention (instructions [here](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support)).
Notes:
-1. `attn_impl: triton` does not work.
1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.
### Intel Gaudi
diff --git a/TUTORIAL.md b/TUTORIAL.md
index 8fa0e41a92..3be4910c4f 100644
--- a/TUTORIAL.md
+++ b/TUTORIAL.md
@@ -32,10 +32,8 @@ This tutorial will provide a brief intro to the repo’s structure and underlyin
- [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on)
- [What hardware can I run inference on?](#what-hardware-can-i-run-inference-on)
- [What is FSDP?](#what-is-fsdp)
- - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton--for-mpt-and-which-one-should-i-use)
+ - [What are the different attention options `torch` / `flash` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--for-mpt-and-which-one-should-i-use)
- [Limitations](#limitations)
- - [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir)
- - [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus)
- [Support for FlashAttention-2](#support-for-flashattention-2)
- [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support)
- [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora)
@@ -144,7 +142,7 @@ name = 'mosaicml/mpt-7b'
# Download config
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
-# (Optional) Use `flash` (preferred) or `triton` backend for fast attention. Defaults to `torch`.
+# (Optional) Use `flash` (preferred) backend for fast attention. Defaults to `torch`.
# config.attn_config['attn_impl'] = 'flash'
# (Optional) Change the `max_seq_len` allowed for inference
# config.max_seq_len = 4096
@@ -291,7 +289,7 @@ The purpose of this section is probably pretty self-evident. You’ve got questi
- If OOMs persist with `device_train_microbatch_size: 1` and `device_eval_batch_size: 1`, you may need to use activation checkpointing `fsdp_config.activation_checkpointing: true` (if you are not already) and, as a last resort, activation CPU offloading `fsdp_config.activation_cpu_offload: true`.
### What hardware can I train on?
-- In general, this repo should work on any system with NVIDIA GPUs. Checkout the `scripts/train/README.md` for more [details on GPU memory requirements]([https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm)). We recommend using `Flash` attention instead of `Triton` attention, unless you're training Prefix Language Models (in which case use `Triton`). Keep in mind you may run into issues with `Flash` or `Triton` support on some GPU types. In that situation, you can fall back to `attn_impl: torch`, or raise an issue in the [Flash Attention github repo](https://github.com/Dao-AILab/flash-attention).
+- In general, this repo should work on any system with NVIDIA GPUs. Checkout the `scripts/train/README.md` for more [details on GPU memory requirements]([https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm)). We recommend using `Flash` attention. Keep in mind you may run into issues with `Flash` support on some GPU types. In that situation, you can fall back to `attn_impl: torch`, or raise an issue in the [Flash Attention github repo](https://github.com/Dao-AILab/flash-attention).
### What hardware can I run eval on?
- Similar to above…
@@ -302,8 +300,8 @@ The purpose of this section is probably pretty self-evident. You’ve got questi
### What is FSDP?
- [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) is a PyTorch implementation of the [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054). FSDP shards networks parameters and the optimizer state across all GPUs. This enables users to train models with large parameter counts which do not fit into a single GPUs memory.
-### What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?
-- **Short answer:** `torch` is the native pytorch attention implementation, and `flash` and `triton` are different implementations of the much more optimized [Flash Attention](https://arxiv.org/abs/2205.14135) method. `triton` and `flash` will be faster (and use less GPU memory) than `torch`, but they might not work with all hardware and environment setups.
+### What are the different attention options `torch` / `flash` for MPT and which one should I use?
+- **Short answer:** `torch` is the native pytorch attention implementation, and `flash` is an implementation of the much more optimized [Flash Attention](https://arxiv.org/abs/2205.14135) method. `flash` will be faster (and use less GPU memory) than `torch`, but they might not work with all hardware and environment setups.
Our training setups typically use `flash`.
@@ -313,7 +311,6 @@ Furthermore, integrating a recomputation schema decreases the sequence length me
- Setting `attn_config.attn_impl=torch` enables a naive Softmax Attention written using base torch operations.
- Setting `attn_config.attn_impl=flash` enables Flash Attention [implemented by Dao et al in the Dao-AILab repo using CUDA](https://github.com/Dao-AILab/flash-attention). This will have linear memory complexity (enabling larger batch sizes) and will run much faster.
- - Setting `attn_config.attn_impl=triton` enables a Flash Attention [implemented using Triton](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/models/layers/flash_attn_triton.py). We recommend using `flash` attention instead of `triton` attention, unless you're training Prefix Language Models (in which case use `Triton`).
+The majority of our training setups use `flash`. -->
#### Limitations
- For training, `torch` uses a lot of memory and is slow.
-- `flash` and `triton` cannot return attention weights and therefore cannot be used with methods that require it.
+- `flash` cannot return attention weights and therefore cannot be used with methods that require it.
- `flash` cannot accept an attention bias. However, it still allows the use of ALiBi positional bias.
-#### What is `triton-pre-mlir`?
-- Torch2 installs and requires a specific version of [Triton](https://openai.com/research/triton).
- `attn_config.attn_impl=triton` requires a different version of triton.
- As a result, you can either use torch2 or `attn_impl=triton`.
- To enable both, we fork triton and make it pip installable as `triton-pre-mlir`.
- `attn_impl=triton` can then use `triton-pre-mlir` leaving the version of triton required for torch2 intact.
-
-#### Known issue with sm86+ GPUs
-- Under the hood, part of `triton-pre-mlir` compile path uses LLVM11.
- H100 GPUs (sm90 GPUs) are not formally supported until LLVM15 (technically it doesn't support anything sm86+).
- Updating the LLVM version used by `triton-pre-mlir` to LLVM13 seems to be relatively easy.
- Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes.
- What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance.
-
#### Support for FlashAttention-2
- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM Foundry supports FlashAttention-2. Please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention).
@@ -352,8 +334,8 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.
| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes |
|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | |
-| ALiBi | model:
attn_config:
alibi: True
| 64.5 | Requires Flash (v2.4.2 or higher) or Triton or Torch attention. |
-| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_impl: dail
| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. |
+| ALiBi | model:
attn_config:
alibi: True
| 64.5 | Requires Flash (v2.4.2 or higher) or Torch attention. |
+| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_impl: dail
| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch` or `flash`. |
| RoPE (Hugging
Face Implementation) | model:
attn_config:
rope: True
rope_impl: hf
| 62.3 | |
### Can I finetune using PEFT / LoRA?
diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index ea9eba74ce..922f738e9a 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -20,15 +20,13 @@
hf_dynamic_modules_logger.addFilter(new_files_warning_filter)
from llmfoundry import algorithms, callbacks, loggers, optim, registry, utils
-from llmfoundry.data import (ConcatTokensDataset, MixtureOfDenoisersCollator,
- NoConcatDataset, Seq2SeqFinetuningCollator,
- build_finetuning_dataloader,
- build_text_denoising_dataloader)
-from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
- ComposerHFT5)
+from llmfoundry.data import (ConcatTokensDataset, NoConcatDataset,
+ Seq2SeqFinetuningCollator,
+ build_finetuning_dataloader)
+from llmfoundry.models.hf import ComposerHFCausalLM, ComposerHFT5
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
- flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn)
+ flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
@@ -36,9 +34,7 @@
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
__all__ = [
- 'build_text_denoising_dataloader',
'build_finetuning_dataloader',
- 'MixtureOfDenoisersCollator',
'Seq2SeqFinetuningCollator',
'MPTBlock',
'FFN_CLASS_REGISTRY',
@@ -50,11 +46,9 @@
'MPTForCausalLM',
'ComposerMPTCausalLM',
'ComposerHFCausalLM',
- 'ComposerHFPrefixLM',
'ComposerHFT5',
'scaled_multihead_dot_product_attention',
'flash_attn_fn',
- 'triton_flash_attn_fn',
'MultiheadAttention',
'NoConcatDataset',
'ConcatTokensDataset',
@@ -70,4 +64,4 @@
'registry',
]
-__version__ = '0.6.0'
+__version__ = '0.7.0'
diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py
index b64f54cfa3..45d1f8237f 100644
--- a/llmfoundry/data/__init__.py
+++ b/llmfoundry/data/__init__.py
@@ -3,8 +3,6 @@
from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
from llmfoundry.data.dataloader import build_dataloader
-from llmfoundry.data.denoising import (MixtureOfDenoisersCollator,
- build_text_denoising_dataloader)
from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator,
build_finetuning_dataloader)
from llmfoundry.data.text_data import (StreamingTextDataset,
@@ -12,12 +10,9 @@
from llmfoundry.registry import dataloaders
dataloaders.register('text', func=build_text_dataloader)
-dataloaders.register('text_denoising', func=build_text_denoising_dataloader)
dataloaders.register('finetuning', func=build_finetuning_dataloader)
__all__ = [
- 'MixtureOfDenoisersCollator',
- 'build_text_denoising_dataloader',
'Seq2SeqFinetuningCollator',
'build_finetuning_dataloader',
'StreamingTextDataset',
diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
deleted file mode 100644
index 303c9298bb..0000000000
--- a/llmfoundry/data/denoising.py
+++ /dev/null
@@ -1,957 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Streaming dataloader for (mixture of) denoising task(s)."""
-
-import logging
-import random
-import sys
-import warnings
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
-
-import numpy as np
-import torch
-from composer.core.data_spec import DataSpec
-from omegaconf import DictConfig
-from omegaconf import OmegaConf as om
-from torch.utils.data import DataLoader
-from transformers import PreTrainedTokenizerBase
-
-from llmfoundry.data.packing import BinPackCollator
-from llmfoundry.data.text_data import (StreamingTextDataset,
- get_tokens_per_batch_func)
-from llmfoundry.models import utils
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
-
-__all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
-
-log = logging.getLogger(__name__)
-
-# HuggingFace hardcodes the ignore index to -100
-_HF_IGNORE_INDEX = -100
-
-# Required signature of any `prefix_function` (see below)
-PREFIX_FUNCTION = Callable[[float, Optional[float], PreTrainedTokenizerBase],
- Sequence[int]]
-
-
-def ul2_prefix_function(
- mask_ratio: float,
- mean_length: Optional[float],
- tokenizer: PreTrainedTokenizerBase,
-) -> Sequence[int]:
- """Generates prefixes based on UL2 paper.
-
- See: http://arxiv.org/abs/2205.05131
- """
- if mean_length is None:
- # This is the case for "sequence to sequence"
- prefix = '[S2S]' if mask_ratio < 1.0 else '[CLM]'
- elif mean_length >= 12 or mask_ratio >= 0.3:
- # UL2 tags this corruption rate "extreme"
- prefix = '[NLG]'
- else:
- # UL2 tags this corruption rate as "regular"
- prefix = '[NLU]'
- return tokenizer(prefix, add_special_tokens=False).input_ids
-
-
-class MixtureOfDenoisersCollator:
- """Data collator for mixture of span-corruption denoisers, as in UL2.
-
- This collator supports a variety of tasks used to pre-train an
- encoder-decoder model or a (prefix LM) decoder-only model. This is meant
- to be used with a dataset that yields tokenized text sequences. It is not
- required that the token sequences are already padded or truncate, as this
- collator will internally truncate and pad as needed.
-
- For the denoising mixture recommended in the original UL2 paper,
- http://arxiv.org/abs/2205.05131, use:
- .. python:
- MixtureOfDenoisersCollator(
- ...,
- span_mean_lengths_and_ratios=[
- [3, .15],
- [8, .15],
- [3, .50],
- [8, .50],
- [64, .15],
- [64, .50],
- ],
- sequence_mask_ratios=0.25
- )
-
- Args:
- tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
- prepare the data from raw text. Any missing sentinel tokens will
- be added by the collator.
- max_seq_length (int): The maximum length of sequences produced by this
- collator. Incoming sequences may be truncated to accommodate this
- limit.
- Note that when formatting for decoder-only models, the context
- tokens and target tokens are concatenated, and max_seq_length
- applies to their combined length. For encoder-decoder models, both
- the encoder and decoder will see up to max_seq_length tokens.
- decoder_only_format (bool, optional): Whether to format the batches
- for a decoder-only model (i.e. a prefix LM) or, if ``False``, an
- encoder-decoder model. Default: ``False``.
- span_mean_lengths_and_rations (optional): A length-2 list of a
- ``[mean_length, mask_ratio]`` pair, or a list of such pairs. Each
- pair adds a span corruption denoising task to the task mixture. For
- example, ``[3, 0.15]`` adds the original span corruption task used
- for pre-training a T5 model as in http://arxiv.org/abs/1910.10683,
- which trained with a single span corruption task that used a mean
- span length of 3 and a mask ratio of 15%.
- Default: ``None`` does not add any span corruption tasks.
- sequence_mask_ratios (optional): A float or list of floats, one for each
- sequence corruption denoising task to add to the task mixture. Each
- sequence mask ratio must be greater than 0.0 and at most 1.0.
- This type of task is a special instance of span corruption, with
- exactly one masked span take from the end of the sequence. The
- length of the span is sampled uniformly such that the average
- portion of masked tokens equals sequence_mask_ratio.
- Note: A value of 1.0 essentially yields causal LM examples.
- Default: ``None` does not add any sequence corruption tasks.
- allow_pad_trimming (bool, optional): Whether to allow the collator to
- trim away sequence regions that are entirely padding (i.e. padding
- for each example in the batch). If ``True``, shorter sequences may
- improve throughput but at a potentially higher memory cost
- owing to variable sequence lengths from batch to batch.
- Default: ``False`` yields batches that are always padded to
- max_seq_length.
- prefix_function (callable, optional): A function that maps denoising
- task parameters (e.g. mean_length=3, mask_ratio=0.15) to a prefix
- that will be added to sequences when the associated "noiser" is
- applied.
- To disable these prefixes, use a value of ``None``.
- Default: :func:`ul2_prefix_function` applies the prefix scheme
- suggested in the UL2 paper: http://arxiv.org/abs/2205.05131.
- context_eos (bool, optional): Whether to attach an EOS token to the end
- of the context sequence, marking the transition from context to
- target sequence. Only applicable if decoder_only_format is True.
- Context EOS tokens are always added for encoder-decoder format.
- Default: ``False`` does not attach context EOS.
- """
-
- def __init__(
- self,
- tokenizer: PreTrainedTokenizerBase,
- max_seq_length: int,
- decoder_only_format: bool = False,
- span_mean_lengths_and_ratios: Optional[List] = None,
- sequence_mask_ratios: Optional[Union[List[float], float]] = None,
- allow_pad_trimming: bool = False,
- prefix_function: Optional[PREFIX_FUNCTION] = ul2_prefix_function,
- context_eos: Optional[bool] = None,
- ):
- # Prepare the tokenizer for denoising tasks
- utils.adapt_tokenizer_for_denoising(tokenizer)
-
- self.tokenizer = tokenizer
- self.max_seq_length = max_seq_length
- self.decoder_only_format = decoder_only_format
- self._sentinel_token_ids = np.array(self.tokenizer.sentinel_token_ids)
-
- # Trimming will always be skipped on at least the first __call__
- self._allow_pad_trimming = allow_pad_trimming
- self._seen_first_batch = False
-
- self.context_eos = bool(context_eos) if decoder_only_format else True
-
- # Process the span_mean_lengths_and_ratios argument
- if span_mean_lengths_and_ratios is None:
- # In this case, there are no span corruption tasks
- self.span_mean_lengths_and_ratios = []
- elif isinstance(span_mean_lengths_and_ratios[0], (int, float)):
- # In this case, there is one span corruption task
- if not len(span_mean_lengths_and_ratios) == 2:
- raise ValueError('`span_mean_lengths_and_ratios` must be a ' + \
- 'pair of [mean_length, mask_ratio], a list ' + \
- f'of such pairs, or None. Got {span_mean_lengths_and_ratios}.')
- self.span_mean_lengths_and_ratios = [span_mean_lengths_and_ratios]
- else:
- # In this case, there are one or more span corruption tasks
- span_mean_lengths_and_ratios = list(span_mean_lengths_and_ratios)
- for spec_pair in span_mean_lengths_and_ratios:
- if len(spec_pair) != 2:
- raise ValueError('`span_mean_lengths_and_ratios` must be a ' + \
- 'pair of [mean_length, mask_ratio], a list ' + \
- f'of such pairs, or None. Got {span_mean_lengths_and_ratios}.')
- self.span_mean_lengths_and_ratios = span_mean_lengths_and_ratios
-
- # Process the sequence_mask_ratios argument
- if sequence_mask_ratios is None:
- # In this case, there are no sequence corruption tasks
- self.sequence_mask_ratios = []
- elif isinstance(sequence_mask_ratios, float):
- # In this case, there is one sequence corruption task
- self.sequence_mask_ratios = [sequence_mask_ratios]
- else:
- # In this case, there is one or more sequence corruption tasks
- for ratio in sequence_mask_ratios:
- if not (0 < ratio <= 1.0):
- raise ValueError('`sequence_mask_ratios` must be a float (or list '+\
- 'of floats) that are each >0.0 and <=1.0, or None. '+\
- f'Got {sequence_mask_ratios}.')
- self.sequence_mask_ratios = sequence_mask_ratios
-
- # Populate the noisers so we can learn to denoise them!
- self._noisers = []
- self._smallest_max_raw_length = self.max_seq_length * 100
- self._largest_max_raw_length = 0
- self._uses_span_corruption = False
-
- # Add "noisers" for any span corruption denoising tasks
- # Each mean_length / mask_ratio combo becomes one of the span
- # corruption denoising tasks
- for span_mean_length, span_mask_ratio in self.span_mean_lengths_and_ratios:
- self._uses_span_corruption = True
- if span_mean_length < 0:
- raise ValueError('All span mean lengths must be positive.')
- if not 0 < span_mask_ratio < 1.0:
- raise ValueError(
- 'All span masking ratios must be between 0.0 and 1.0.')
-
- if prefix_function is not None:
- prefix_tokens = prefix_function(span_mask_ratio,
- span_mean_length,
- self.tokenizer)
- else:
- prefix_tokens = None
-
- max_raw_length = _get_max_starting_length(
- max_length=self.max_seq_length,
- mask_ratio=span_mask_ratio,
- mean_span_length=span_mean_length,
- n_prefix_tokens=len(prefix_tokens or []),
- decoder_only_format=self.decoder_only_format,
- context_eos=self.context_eos)
- if max_raw_length < self._smallest_max_raw_length:
- self._smallest_max_raw_length = max_raw_length
- if max_raw_length > self._largest_max_raw_length:
- self._largest_max_raw_length = max_raw_length
-
- kwargs = {
- 'mean_span_length': span_mean_length,
- 'mask_ratio': span_mask_ratio,
- 'prefix_tokens': prefix_tokens,
- 'max_raw_length': max_raw_length,
- }
- self._noisers.append(kwargs)
-
- # Add "noisers" for any sequential denoising tasks
- for sequence_mask_ratio in self.sequence_mask_ratios:
- if prefix_function is not None:
- prefix_tokens = prefix_function(sequence_mask_ratio, None,
- self.tokenizer)
- else:
- prefix_tokens = None
-
- max_raw_length = self.max_seq_length - len(prefix_tokens or []) - 1
- if decoder_only_format and self.context_eos:
- max_raw_length = max_raw_length - 1
-
- if not self._uses_span_corruption and (
- max_raw_length < self._smallest_max_raw_length):
- # We choose not to count sequence denoising in the smallest
- # unless there is only sequence denoising.
- self._smallest_max_raw_length = max_raw_length
- if max_raw_length > self._largest_max_raw_length:
- self._largest_max_raw_length = max_raw_length
-
- kwargs = {
- 'mean_span_length': None,
- 'mask_ratio': sequence_mask_ratio,
- 'prefix_tokens': prefix_tokens,
- 'max_raw_length': max_raw_length,
- }
- self._noisers.append(kwargs)
-
- if not self._noisers:
- raise ValueError(
- 'No denoising tasks were included. Make sure to set ' + \
- '`span_mean_lengths_and_ratios` and/or `sequence_mask_ratios`.')
-
- @property
- def smallest_max_raw_length(self) -> int:
- return int(self._smallest_max_raw_length)
-
- @property
- def largest_max_raw_length(self) -> int:
- return int(self._largest_max_raw_length)
-
- def __call__(self, examples: List[Dict[str,
- Any]]) -> Dict[str, torch.Tensor]:
- """Batch examples processed by the span corrupter."""
- processed_examples = []
- for example in examples:
- # Randomly pick a "noiser" to apply to this example
- noiser = random.choice(self._noisers)
- # Apply it
- processed_examples.append(
- noise_token_sequence(
- example,
- mask_ratio=noiser['mask_ratio'],
- mean_span_length=noiser['mean_span_length'],
- prefix_tokens=noiser['prefix_tokens'],
- max_raw_length=noiser['max_raw_length'],
- max_seq_length=self.max_seq_length,
- tokenizer=self.tokenizer,
- sentinel_token_ids=self._sentinel_token_ids,
- decoder_only_format=self.decoder_only_format,
- context_eos=self.context_eos))
- batch = self.tokenizer.pad(processed_examples)
-
- # This logic prevents trimming on at least the first batch
- if not (self._allow_pad_trimming and self._seen_first_batch):
- self._seen_first_batch = True
- return batch
- self._seen_first_batch = True
-
- # Truncate portions of the inputs that are purely padding
- # (up to a multiple of 8)
- multiple_of = 8
- n_examples_per_length = batch['attention_mask'].sum(0)
- keep_tokens = torch.sum(n_examples_per_length > 0)
- keep_tokens = int(multiple_of * torch.ceil(keep_tokens / multiple_of))
-
- # Note: EncDec formatting will always produce a right-padded batch
- if self.tokenizer.padding_side == 'left' and self.decoder_only_format:
- batch['input_ids'] = batch['input_ids'][:, -keep_tokens:]
- batch['attention_mask'] = batch['attention_mask'][:, -keep_tokens:]
- else:
- batch['input_ids'] = batch['input_ids'][:, :keep_tokens]
- batch['attention_mask'] = batch['attention_mask'][:, :keep_tokens]
-
- if self.decoder_only_format:
- if self.tokenizer.padding_side == 'left':
- batch['labels'] = batch['labels'][:, -keep_tokens:]
- batch['bidirectional_mask'] = batch[
- 'bidirectional_mask'][:, -keep_tokens:]
- else:
- batch['labels'] = batch['labels'][:, :keep_tokens]
- batch['bidirectional_mask'] = batch[
- 'bidirectional_mask'][:, :keep_tokens]
-
- else:
- # Truncate portions of the decoder inputs that are purely padding
- n_examples_per_length = batch['decoder_attention_mask'].sum(0)
- keep_tokens = torch.sum(n_examples_per_length > 0)
- keep_tokens = int(multiple_of *
- torch.ceil(keep_tokens / multiple_of))
-
- batch['labels'] = batch['labels'][:, :keep_tokens]
- batch['decoder_attention_mask'] = batch[
- 'decoder_attention_mask'][:, :keep_tokens]
- batch['decoder_input_ids'] = batch[
- 'decoder_input_ids'][:, :keep_tokens]
-
- # This slicing can produce non-contiguous tensors, so use .contiguous
- # to prevent related problems
- batch = {k: v.contiguous() for k, v in batch.items()}
-
- return batch
-
-
-def build_text_denoising_dataloader(
- cfg: DictConfig,
- tokenizer: PreTrainedTokenizerBase,
- device_batch_size: int,
-) -> DataSpec:
- """Constructor function for a Mixture of Denoisers dataloader.
-
- This function constructs a dataloader that can be used to train an
- encoder-decoder model or a (prefix LM) decoder-only model on a text
- denoising task mixture (e.g. span corruption, or UL2).
-
- The underlying dataset is a :class:`StreamingTextDataset`, allowing you to
- stream raw text data or pre-tokenized text data.
-
- The dataloader uses a :class:`MixtureOfDenoisersCollator` to prepare the
- tokenized examples into training batches.
-
- Args:
- cfg (DictConfig): An omegaconf dictionary used to configure the loader:
- cfg.name (str): The type of dataloader to build. Must = "text_denoising".
- ---
- cfg.dataset.max_seq_len (int): The maximum length of sequences
- in the batch. See :class:`MixtureOfDenoisersCollator` docstring
- for details.
- cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes
- a collator wrapper that packs device_batch_size*packing_ratio
- raw examples into device_batch_size packed examples. This helps
- minimize padding while preserving sequence integrity.
- This adds `sequence_id` to the batch, which indicates which unique
- sequence each token belongs to.
-
- If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with
- zero waste is selected.
- In practice, this may result in > 0 waste because profiling is done on only a portion
- of the dataset.
-
- Note: Using this feature will not change device_batch_size but it
- will determine the number of raw examples consumed by the dataloader
- per batch. Some examples may be discarded if they do not fit when
- packing.
- Select packing_ratio **carefully** based on the dataset
- statistics, max_seq_len, and tolerance for discarding samples!
- The script `scripts/misc/profile_packing.py` can help
- you choose the best packing_ratio.
- See :class:`StreamingTextDataset` for info on other standard config
- options within `cfg.dataset`.
- ---
- cfg.mixture_of_denoisers.decoder_only_format (bool): Whether the
- batches should use the format required for training a decoder-only
- model (if ``True``) or an encoder-decoder model (if ``False``).
- cfg.mixture_of_denoisers.span_mean_lengths_and_ratios (optional): The
- parameters for any span corruption denoising tasks to include in
- the task mixture.
- See :class:`MixtureOfDenoisersCollator` docstring for details.
- cfg.mixture_of_denoisers.sequence_mask_ratios (optional): The
- parameters for any sequence denoising tasks to include in the
- task mixture.
- See :class:`MixtureOfDenoisersCollator` docstring for details.
- cfg.mixture_of_denoisers.allow_pad_trimming (optional): Whether to
- allow the collator to trim padding when possible (if ``True``).
- Defaults to ``False``.
- cfg.mixture_of_denoisers.prefix_function (optional): Set to ``None``
- to disable the UL2-style prefixes that will be automatically
- added by default.
- ---
- See :class:`DataLoader` for standard argument options to the pytorch
- dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc.
- tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
- prepare the data from raw text. Any missing sentinel tokens will
- be added by the collator.
- device_batch_size (int): The size of the batches (number of examples)
- that the dataloader will produce.
-
- Note:
- You can use the script `scripts/misc/profile_packing.py` to quickly test the
- padding/waste rates for different `cfg.dataset.packing_ratio` choices,
- given a starting workload YAML.
- """
- warnings.warn(
- VersionedDeprecationWarning('Text denoising is deprecated.',
- remove_version='0.7.0'))
- assert cfg.name == 'text_denoising', f'Tried to build_denoising text dataloader with cfg.name={cfg.name}'
-
- collate_fn = MixtureOfDenoisersCollator(
- tokenizer=tokenizer,
- max_seq_length=cfg.dataset.max_seq_len,
- decoder_only_format=cfg.mixture_of_denoisers.decoder_only_format,
- span_mean_lengths_and_ratios=cfg.mixture_of_denoisers.get(
- 'span_mean_lengths_and_ratios'),
- sequence_mask_ratios=cfg.mixture_of_denoisers.get(
- 'sequence_mask_ratios'),
- allow_pad_trimming=cfg.mixture_of_denoisers.get('allow_pad_trimming',
- False),
- prefix_function=cfg.mixture_of_denoisers.get('prefix_function',
- ul2_prefix_function),
- context_eos=cfg.mixture_of_denoisers.get('context_eos'))
-
- truncate_to = cfg.mixture_of_denoisers.get('truncate_raw_tokens_to')
- if truncate_to is None:
- # By default, truncate to the largest max raw length of the denoisers
- truncate_to = collate_fn.largest_max_raw_length
- elif isinstance(truncate_to, str):
- if truncate_to.lower() == 'min':
- # Truncate to the smallest max raw length of the denoisers
- truncate_to = collate_fn.smallest_max_raw_length
- elif truncate_to.lower() == 'max':
- # Truncate to the largest max raw length of the denoisers
- truncate_to = collate_fn.largest_max_raw_length
- else:
- raise ValueError(
- f'truncate_raw_tokens_to(="{truncate_to.lower()}") must be "min", "max", a positive int, or None.'
- )
- else:
- if not isinstance(truncate_to, int):
- ValueError(
- f'truncate_raw_tokens_to(={truncate_to}) must be "min", "max", a positive int, or None.'
- )
- if truncate_to < 0:
- ValueError(
- f'truncate_raw_tokens_to(={truncate_to}) must be "min", "max", a positive int, or None.'
- )
-
- dataset = StreamingTextDataset(
- local=cfg.dataset.local,
- tokenizer=tokenizer,
- max_seq_len=truncate_to,
- remote=cfg.dataset.get('remote'),
- split=cfg.dataset.get('split'),
- shuffle=cfg.dataset.get('shuffle', False),
- predownload=cfg.dataset.get('predownload', None),
- keep_zip=cfg.dataset.get('keep_zip', False),
- download_retry=cfg.dataset.get('download_retry', 2),
- download_timeout=cfg.dataset.get('download_timeout', 60),
- validate_hash=cfg.dataset.get('validate_hash', None),
- shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
- num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
- batch_size=device_batch_size,
- )
-
- if dataset.tokenizer.pad_token is None:
- dataset.tokenizer.pad_token = dataset.tokenizer.eos_token
-
- if cfg.dataset.get('packing_ratio'):
- n_examples_to_pack = int(device_batch_size * cfg.dataset.packing_ratio)
- if n_examples_to_pack < device_batch_size:
- raise ValueError('packing_ratio must be >= 1, if supplied')
- if not cfg.mixture_of_denoisers.decoder_only_format:
- raise NotImplementedError(
- 'On-the-fly packing is currently only supported for decoder-only formats.'
- )
- collate_fn = BinPackCollator(
- collator=collate_fn,
- target_batch_size=device_batch_size,
- max_seq_len=cfg.dataset.max_seq_len,
- pad_token_id=dataset.tokenizer.pad_token_id,
- padding_side=dataset.tokenizer.padding_side,
- max_leftover_bins_to_keep=cfg.dataset.get(
- 'max_leftover_bins_to_keep'),
- )
- device_batch_size = n_examples_to_pack
- elif cfg.dataset.get('max_leftover_bins_to_keep') is not None:
- raise ValueError(
- 'cfg.dataset.max_leftover_bins_to_keep has been defined, ' +\
- 'but cfg.dataset.packing_ratio has not been set. Please set ' +\
- 'the latter to turn on packing or remove the former from the config.')
-
- dl = DataLoader(
- dataset,
- collate_fn=collate_fn,
- batch_size=device_batch_size,
- drop_last=cfg.drop_last,
- num_workers=cfg.num_workers,
- pin_memory=cfg.get('pin_memory', True),
- prefetch_factor=cfg.get('prefetch_factor', 2),
- persistent_workers=cfg.get('persistent_workers', False),
- timeout=cfg.get('timeout', 0),
- )
-
- token_counting_func = get_tokens_per_batch_func(
- decoder_only=cfg.mixture_of_denoisers.decoder_only_format)
-
- return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
-
-
-def noise_token_sequence(
- example: Union[torch.Tensor, Mapping[str, Any]],
- mask_ratio: float,
- mean_span_length: Optional[float],
- prefix_tokens: Optional[Sequence[int]],
- max_raw_length: int,
- max_seq_length: int,
- tokenizer: PreTrainedTokenizerBase,
- sentinel_token_ids: np.ndarray,
- decoder_only_format: bool,
- context_eos: bool,
-) -> Dict[str, torch.Tensor]:
- """Span corruption applicable to all UL2 denoising tasks."""
- # Extract the raw text tokens (trim if we need to)
- if isinstance(example, torch.Tensor):
- # If the example is a tensor, assume is the raw tokens with no padding
- tokens = example
- length = len(tokens)
- else:
- tokens = example['input_ids']
- length = sum(example['attention_mask'])
- if length > max_raw_length:
- length = max_raw_length
- if tokenizer.padding_side == 'left':
- tokens = tokens[-length:]
- else:
- tokens = tokens[:length]
-
- prefix_tokens = prefix_tokens or []
-
- if length < 1:
- raise ValueError('Example cannot be empty but token length <1.')
-
- # mean_span_length==None is a special case for "sequential" denoising
- # (where a single span at the end of the sequence is masked)
- if mean_span_length is None:
- # This ensures that exactly 1 span will be produced and that
- # trimming to max_seq_length won't cut off any token.
- # In the decoder-only case, this won't insert new tokens.
- if mask_ratio <= 0.5:
- u = np.random.uniform(low=0.0, high=mask_ratio * 2)
- else:
- u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
- mean_span_length = float(np.round(1 + u * (length - 1)))
- mask_ratio = mean_span_length / length
- use_sentinels = False
- else:
- use_sentinels = True
-
- # Generate the mask
- # Note: this function can be used for all the UL2 noising functions
- mask = _sample_mask_array(length, mask_ratio, mean_span_length)
- # The sequence should always be unmasked at the beginning
- assert mask[0] == 0
-
- # Generate the input/label sequences given the raw tokens and the mask
- tokens_inputs = _apply_mask(tokens,
- mask,
- use_sentinels,
- tokenizer.eos_token_id,
- sentinel_token_ids,
- ensure_eos=context_eos)
- tokens_labels = _apply_mask(tokens,
- 1 - mask,
- use_sentinels,
- tokenizer.eos_token_id,
- sentinel_token_ids,
- ensure_eos=True)
-
- # Tag the inputs with any prefix
- if prefix_tokens:
- tokens_inputs = np.concatenate([prefix_tokens, tokens_inputs])
-
- # Trim if necessary
- if len(tokens_inputs) > max_seq_length:
- raise ValueError('This should not exceed the max length')
- if len(tokens_labels) > max_seq_length:
- raise ValueError('This should not exceed the max length')
-
- tokens_inputs = torch.LongTensor(tokens_inputs)
- tokens_labels = torch.LongTensor(tokens_labels)
-
- if decoder_only_format:
- return _format_tokens_for_decoder_only(tokens_inputs, tokens_labels,
- max_seq_length,
- tokenizer.pad_token_id,
- tokenizer.padding_side)
- return _format_tokens_for_encoder_decoder(tokens_inputs, tokens_labels,
- max_seq_length,
- tokenizer.pad_token_id)
-
-
-def _get_max_starting_length(max_length: int, mask_ratio: float,
- mean_span_length: float, n_prefix_tokens: int,
- decoder_only_format: bool,
- context_eos: bool) -> int:
- """Get max num raw tokens that will fit max_length."""
-
- def sequence_stats(length: int):
- length = np.maximum(length, 2)
- num_noise_tokens = int(np.round(mask_ratio * float(length)))
- num_noise_tokens = np.minimum(np.maximum(num_noise_tokens, 1),
- length - 1)
- num_spans = int(np.round(float(num_noise_tokens) / mean_span_length))
- num_noise_spans = np.maximum(num_spans, 1)
- num_nonnoise_tokens = length - num_noise_tokens
- # Prefix, sentinel, and EOS added to input for Enc-Dec
- extra_inp_tokens = n_prefix_tokens + num_noise_spans + int(context_eos)
- # Sentinel and EOS added to target
- extra_targ_tokens = num_noise_spans + 1
- # Sequence totals after corruption
- total_inp_tokens = num_nonnoise_tokens + extra_inp_tokens
- total_targ_tokens = num_noise_tokens + extra_targ_tokens
- return total_inp_tokens, total_targ_tokens
-
- def length_fits(length: int) -> bool:
- total_inp_tokens, total_targ_tokens = sequence_stats(length)
- if decoder_only_format:
- return (total_inp_tokens + total_targ_tokens) <= max_length
- return (total_inp_tokens <= max_length) and (total_targ_tokens <=
- max_length)
-
- # Start with a definitely too-long sequence and reduce until it fits
- num_raw_tokens = max_length * 2
- while num_raw_tokens > 0:
- if length_fits(num_raw_tokens):
- return num_raw_tokens
- num_raw_tokens -= 1
- raise ValueError(
- 'Unable to find a starting sequence length that can fit given the corruption and max_length parameters.'
- )
-
-
-def _sample_mask_array(length: int, mask_ratio: float,
- mean_span_length: float) -> np.ndarray:
- """Samples a span corruption mask."""
- if mask_ratio == 0.0:
- return np.zeros(length)
- # This first block computes the number of noise/non-noise spans and the
- # total tokens in each. Extra steps are taken to handle edge cases that
- # cause degeneracy.
- starting_length = length
- length = np.maximum(length, 2)
- num_noise_tokens = int(np.round(mask_ratio * float(length)))
- num_noise_tokens = np.minimum(np.maximum(num_noise_tokens, 1), length - 1)
- num_spans = int(np.round(float(num_noise_tokens) / mean_span_length))
- num_noise_spans = np.maximum(num_spans, 1)
- num_nonnoise_tokens = length - num_noise_tokens
-
- # Sample the noise/non-noise span lengths and interleave them to
- # generate the mask array.
- # Note: We always start with a non-noise span.
- def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray:
- """Samples lengths of num_spans segments.
-
- Note: the combined length of segments equals total_tokens.
- """
- span_markers = np.less(np.arange(total_tokens - 1), num_spans -
- 1)[np.random.permutation(total_tokens - 1)]
- span_start_indicator = np.concatenate([np.array([0]), span_markers])
- span_id = np.cumsum(span_start_indicator).reshape(-1, 1)
- spans = np.arange(num_spans).reshape(1, -1)
- span_lengths = np.sum(span_id == spans, axis=0)
- return span_lengths
-
- noise_span_lengths = _sample_span_lengths(num_noise_tokens, num_noise_spans)
- nonnoise_span_lengths = _sample_span_lengths(num_nonnoise_tokens,
- num_noise_spans)
- interleaved_span_lengths = np.reshape(
- np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
- [num_noise_spans * 2])
-
- span_starts = np.cumsum(interleaved_span_lengths)[:-1]
- span_start_indicator = np.zeros(length)
- span_start_indicator[span_starts] = 1
- span_id = np.cumsum(span_start_indicator)
- is_noise = np.equal(np.mod(span_id, 2), 1)
-
- mask = is_noise[:starting_length]
-
- return mask
-
-
-def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],
- mask: np.ndarray,
- use_sentinels: bool,
- eos_token_id: int,
- sentinel_token_ids: np.ndarray,
- ensure_eos: bool = True) -> np.ndarray:
- """Remove or replace masked portions from token sequence."""
- if not use_sentinels:
- # The logic is simple if we don't use sentinel tokens
- noised_tokens = np.array(tokens)[np.logical_not(mask)]
-
- # Ensure there's an end-of-sentence token at the end
- if ensure_eos and (noised_tokens[-1] != eos_token_id):
- noised_tokens = np.concatenate(
- [noised_tokens, np.array([eos_token_id])])
-
- return noised_tokens
-
- # Masking at previous token
- prev_token_mask = np.concatenate([np.array([0]), mask[:-1]])
-
- # Decompose mask into start-of-span mask and non-start-of-span mask
- start_of_noise_span_token = np.logical_and(mask,
- np.logical_not(prev_token_mask))
- nonstart_noise_span_token = np.logical_and(mask, prev_token_mask)
-
- # Replace tokens at the start of each noise span with its corresponding
- # sentinel token
- sentinel_idx = np.minimum(len(sentinel_token_ids),
- np.cumsum(start_of_noise_span_token)) - 1
- tokens = np.where(start_of_noise_span_token,
- sentinel_token_ids[sentinel_idx], tokens)
-
- # Remove masked tokens (but preserving the sentinel tokens)
- noised_tokens = tokens[np.logical_not(nonstart_noise_span_token)]
-
- # Ensure there's an end-of-sentence token at the end
- if ensure_eos and (noised_tokens[-1] != eos_token_id):
- noised_tokens = np.concatenate(
- [noised_tokens, np.array([eos_token_id])])
- return noised_tokens
-
-
-def _format_tokens_for_encoder_decoder(
- tokens_inputs: torch.LongTensor,
- tokens_labels: torch.LongTensor,
- max_seq_length: int,
- pad_token_id: int,
-) -> Dict[str, torch.Tensor]:
- """Package the input/label sequence for an EncDec model."""
- example = {}
- # Re-populate with an empty, padded example
- example['input_ids'] = torch.full((max_seq_length,),
- pad_token_id,
- dtype=torch.int32)
- example['labels'] = torch.full((max_seq_length,),
- _HF_IGNORE_INDEX,
- dtype=torch.int32)
- example['attention_mask'] = torch.zeros_like(example['input_ids'])
- example['decoder_attention_mask'] = torch.zeros_like(example['labels'])
-
- # Fill in with processed results (Note: EncDec format is right-padded)
- example['input_ids'][:len(tokens_inputs)] = tokens_inputs
- example['labels'][:len(tokens_labels)] = tokens_labels
- example['attention_mask'][:len(tokens_inputs)] = 1
- example['decoder_attention_mask'][:len(tokens_labels)] = 1
-
- # Best practice is to include decoder_input_ids (= right-shifted labels)
- example['decoder_input_ids'] = torch.full_like(example['labels'],
- pad_token_id)
- example['decoder_input_ids'][1:len(tokens_labels)] = tokens_labels[:-1]
- return example
-
-
-def _format_tokens_for_decoder_only(
- tokens_inputs: torch.LongTensor,
- tokens_labels: torch.LongTensor,
- max_seq_length: int,
- pad_token_id: int,
- padding_side: str,
-) -> Dict[str, torch.Tensor]:
- """Package the input/label sequence for an decoder-only model."""
- example = {}
- # Re-populate with an empty, padded example
- example['input_ids'] = torch.full((max_seq_length,),
- pad_token_id,
- dtype=torch.int32)
- example['labels'] = torch.full((max_seq_length,),
- _HF_IGNORE_INDEX,
- dtype=torch.int32)
- example['attention_mask'] = torch.full((max_seq_length,),
- 0,
- dtype=torch.bool)
- example['bidirectional_mask'] = torch.full((max_seq_length,),
- 0,
- dtype=torch.bool)
-
- n_input = len(tokens_inputs)
- n_label = len(tokens_labels)
- n_concat = n_input + n_label
- assert n_concat <= max_seq_length, f'{n_concat=}, {n_input=}, {n_label=}'
-
- tokens_concat = torch.concat([tokens_inputs, tokens_labels], dim=0)
-
- # Fill in with the processed results
- if padding_side == 'left':
- example['input_ids'][-n_concat:] = tokens_concat
- # `labels` copies `input_ids` but with -100 at
- # non-loss-generating tokens. `labels` will be shifted in the
- # model code when computing loss.
- example['labels'][-n_concat:] = tokens_concat
- example['labels'][-n_concat:-n_label] = _HF_IGNORE_INDEX
- example['attention_mask'][-n_concat:] = 1
- example['bidirectional_mask'][-n_concat:-n_label] = 1
- else:
- example['input_ids'][:n_concat] = tokens_concat
- # See above comment regarding `labels`
- example['labels'][:n_concat] = tokens_concat
- example['labels'][:n_input] = _HF_IGNORE_INDEX
- example['attention_mask'][:n_concat] = 1
- example['bidirectional_mask'][:n_input] = 1
- return example
-
-
-# Helpful to test if your dataloader is working locally
-# Run `python denoising.py [local] [remote, optional]` and verify that batches
-# are printed out
-if __name__ == '__main__':
- from llmfoundry.utils.builders import build_tokenizer
-
- local = sys.argv[1]
- if len(sys.argv) > 2:
- remote = sys.argv[2]
- else:
- remote = local
- print(f'Reading val split from {remote} -> {local}')
-
- decoder_only = True
-
- cfg = {
- 'name': 'text_denoising',
- 'dataset': {
- 'local': local,
- 'remote': remote,
- 'split': 'val', # 'val_small',
- 'shuffle': False,
- 'max_seq_len': 2048 if decoder_only else 1024,
- 'packing_ratio': 4.5,
- 'predownload': 1000,
- 'keep_zip': True, # in case we need compressed files after testing
- },
- 'mixture_of_denoisers': {
- 'decoder_only_format': decoder_only,
- 'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
- 'sequence_mask_ratios': 0.25,
- },
- 'drop_last': False,
- 'num_workers': 0,
- }
- cfg = om.create(cfg)
- device_batch_size = 2
-
- tokenizer_name = 'EleutherAI/gpt-neox-20b' if decoder_only else 't5-base'
- tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
- tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
- tokenizer_kwargs=tokenizer_kwargs)
-
- loader = build_text_denoising_dataloader(cfg, tokenizer,
- device_batch_size).dataloader
- assert isinstance(loader, DataLoader)
- assert isinstance(loader.dataset, StreamingTextDataset)
-
- print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
-
- packing = cfg.dataset.get('packing_ratio') is not None
- if packing:
- tokenizer = loader.collate_fn.base_collator.tokenizer
- else:
- tokenizer = loader.collate_fn.tokenizer
- batch_ix = 0
- for batch in loader:
- if batch_ix >= 50:
- batch_ix += 1
- break
- if batch_ix >= 5:
- if not packing:
- break
- batch_ix += 1
- continue
- print('\n')
- print('#' * 20, f'Batch {batch_ix}', '#' * 20)
- for k, v in batch.items():
- print(k, v.shape, v.dtype)
- for sample_ix, token_sample in enumerate(batch['input_ids']):
- if cfg.mixture_of_denoisers.decoder_only_format:
- labels = batch['labels'][sample_ix]
- attn_inputs = batch['bidirectional_mask'][sample_ix].to(
- torch.bool)
- attn_full = batch['attention_mask'][sample_ix].to(torch.bool)
- attn_labels = torch.logical_xor(attn_inputs, attn_full)
- print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
- if packing:
- for subseq in range(
- int(batch['sequence_id'][sample_ix].max()) + 1):
- is_subseq = batch['sequence_id'][sample_ix] == subseq
- print(
- '\033[93m{}\033[00m\n'.format('Input: '),
- tokenizer.decode(token_sample[torch.logical_and(
- is_subseq, attn_inputs)]))
- print(
- '\033[92m{}\033[00m\n'.format('Target: '),
- tokenizer.decode(labels[torch.logical_and(
- is_subseq, attn_labels)]))
- else:
- print('\033[91m{}\033[00m\n'.format('Full: '),
- tokenizer.decode(token_sample[attn_full]))
- print('\033[93m{}\033[00m\n'.format('Input: '),
- tokenizer.decode(token_sample[attn_inputs]))
- print('\033[92m{}\033[00m\n'.format('Target: '),
- tokenizer.decode(labels[attn_labels]))
- else:
- labels = batch['labels'][sample_ix]
- attn_inputs = batch['attention_mask'][sample_ix].to(torch.bool)
- attn_labels = batch['decoder_attention_mask'][sample_ix].to(
- torch.bool)
- print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
- print('\033[93m{}\033[00m\n'.format('Input: '),
- tokenizer.decode(token_sample[attn_inputs]))
- print('\033[92m{}\033[00m\n'.format('Target: '),
- tokenizer.decode(labels[attn_labels]))
- batch_ix += 1
-
- if packing:
- print(f'Padding = {100*(1-loader.collate_fn.efficiency):5.2f}%')
- print(f'Waste = {100*loader.collate_fn.waste:5.2f}%')
diff --git a/llmfoundry/data/finetuning/collator.py b/llmfoundry/data/finetuning/collator.py
index d82fad5ec2..6e3babd657 100644
--- a/llmfoundry/data/finetuning/collator.py
+++ b/llmfoundry/data/finetuning/collator.py
@@ -331,31 +331,21 @@ def _process_and_batch_decoder_only(
self._warned_truncated = True
attention_mask = [1] * len(input_ids)
- # bidirectional_mask is used by our prefix lm model variants
- # Note: this will be malformed if any loss-generating tokens are followed by non-loss-generating tokens
- # (such as in the case of multi-turn chat examples)
- bidirectional_mask = [
- 1 if label == _HF_IGNORE_INDEX else 0 for label in labels
- ]
# Annoyingly, we need to pad everything but input_ids
# and attention_mask ourselves
n_total = len(input_ids)
i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
- z_pad = [0] * (self.max_seq_len - n_total)
if self.tokenizer.padding_side == 'left':
labels = i_pad + labels
- bidirectional_mask = z_pad + bidirectional_mask
else:
labels = labels + i_pad
- bidirectional_mask = bidirectional_mask + z_pad
# Update the example
processed_example = {
'input_ids': input_ids,
'labels': labels,
'attention_mask': attention_mask,
- 'bidirectional_mask': bidirectional_mask,
}
processed_examples.append(processed_example)
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 515888a664..38c9673a14 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -18,6 +18,8 @@
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func
+from llmfoundry.utils.exceptions import (MissingHuggingFaceURLSplitError,
+ NotEnoughDatasetSamplesError)
log = logging.getLogger(__name__)
@@ -174,15 +176,12 @@ def build_finetuning_dataloader(cfg: DictConfig,
# Build HF dataloader
dataset_name_or_path = cfg.dataset.hf_name
split = cfg.dataset.get('split')
+ if split is None:
+ raise MissingHuggingFaceURLSplitError()
# If dataset is a remote path, download it first.
backend, _, _ = parse_uri(dataset_name_or_path)
if backend not in ['', None]:
- if split is None:
- raise ValueError(
- 'When using a HuggingFace dataset from a URL, you must set the ' + \
- '`split` key in the dataset config.'
- )
dataset_name_or_path = _download_remote_hf_dataset(
remote_path=dataset_name_or_path, split=split)
split = split.replace('-', '_')
@@ -218,17 +217,13 @@ def build_finetuning_dataloader(cfg: DictConfig,
if hasattr(dataset, '__len__'):
full_dataset_size = len(dataset)
if full_dataset_size < minimum_dataset_size:
- raise ValueError(
- f'Your dataset (name={cfg.dataset.hf_name}, split={split}) '
- +
- f'has {full_dataset_size} samples, but your minimum batch size '
- +
- f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
- +
- f'your per device batch size is {dataloader_batch_size}. Please increase the number '
- +
- f'of samples in your dataset to at least {minimum_dataset_size}.'
- )
+ raise NotEnoughDatasetSamplesError(
+ dataset_name=cfg.dataset.hf_name,
+ split=split,
+ dataloader_batch_size=dataloader_batch_size,
+ world_size=world_size,
+ full_dataset_size=full_dataset_size,
+ minimum_dataset_size=minimum_dataset_size)
# Initialize sampler.
sampler = dist.get_sampler(dataset,
drop_last=cfg.drop_last,
@@ -557,13 +552,13 @@ def _build_collate_fn(
1)],
skip_special_tokens=False,
clean_up_tokenization_spaces=True))
+ context = torch.logical_and(
+ batch['attention_mask'][j] == 1,
+ batch['labels'][j] == _HF_IGNORE_INDEX)
print(
'\033[92m{}\033[00m\n'.format('CONTEXT: '),
tokenizer.decode(batch['input_ids'][
- j,
- torch.logical_and(
- is_subseq, batch['bidirectional_mask'][j] ==
- 1)],
+ j, torch.logical_and(is_subseq, context)],
skip_special_tokens=False,
clean_up_tokenization_spaces=True))
print(
@@ -583,10 +578,12 @@ def _build_collate_fn(
batch['attention_mask'][j] == 1],
skip_special_tokens=False,
clean_up_tokenization_spaces=True))
+ context = torch.logical_and(
+ batch['attention_mask'][j] == 1,
+ batch['labels'][j] == _HF_IGNORE_INDEX)
print(
'\033[92m{}\033[00m\n'.format('CONTEXT: '),
- tokenizer.decode(batch['input_ids'][
- j, batch['bidirectional_mask'][j] == 1],
+ tokenizer.decode(batch['input_ids'][j, context],
skip_special_tokens=False,
clean_up_tokenization_spaces=True))
print(
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 8b5bbaa654..4ca15e8d1f 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -38,7 +38,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from collections.abc import Mapping
from functools import partial
from pathlib import Path
-from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, Set,
+from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
Tuple, Union, cast)
import datasets as hf_datasets
@@ -51,6 +51,21 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from llmfoundry.data.finetuning.collator import (_HF_IGNORE_INDEX,
stitch_turns_decoder_only,
stitch_turns_encoder_decoder)
+# yapf: disable
+from llmfoundry.utils.exceptions import (ConsecutiveRepeatedChatRolesError,
+ IncorrectMessageKeyQuantityError,
+ InvalidContentTypeError,
+ InvalidFileExtensionError,
+ InvalidLastChatMessageRoleError,
+ InvalidPromptResponseKeysError,
+ InvalidPromptTypeError,
+ InvalidResponseTypeError,
+ InvalidRoleError,
+ NotEnoughChatDataError,
+ TooManyKeysInExampleError,
+ UnableToProcessPromptResponseError,
+ UnknownExampleTypeError)
+# yapf: enable
from llmfoundry.utils.logging_utils import SpecificWarningFilter
log = logging.getLogger(__name__)
@@ -94,13 +109,11 @@ def _get_example_type(example: Example) -> ExampleType:
if any(allowed_message_key in example
for allowed_message_key in _ALLOWED_MESSAGES_KEYS):
return 'chat'
- elif any([
- pr in example
- for pr in _ALLOWED_PROMPT_KEYS.union(_ALLOWED_RESPONSE_KEYS)
- ]):
+ elif any(p in example for p in _ALLOWED_PROMPT_KEYS) and any(
+ r in example for r in _ALLOWED_RESPONSE_KEYS):
return 'prompt_response'
else:
- raise KeyError(f'Unknown conversation type {example=}')
+ raise UnknownExampleTypeError(example)
def _is_empty_or_nonexistent(dirpath: str) -> bool:
@@ -115,15 +128,14 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0
-def _get_key(dictionary: Mapping[str, Any], allowed_keys: Set[str]):
+def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]):
if not isinstance(dictionary, Mapping):
raise TypeError(
f'Expected dictionary to be a mapping, but found {type(dictionary)}'
)
desired_keys = allowed_keys.intersection(dictionary.keys())
if len(desired_keys) != 1:
- raise ValueError(
- f'Dictionary has multiple keys in `allowed_keys`: {desired_keys}')
+ raise TooManyKeysInExampleError(allowed_keys, desired_keys)
return list(desired_keys)[0]
@@ -136,26 +148,29 @@ def _validate_chat_formatted_example(example: ChatFormattedDict):
raise TypeError(
f'Expected messages to be an iterable, but found {type(messages)}')
if len(messages) <= 1:
- raise ValueError('Chat example must have at least two messages')
+ raise NotEnoughChatDataError()
last_message = messages[-1]
role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS)
last_role = last_message[role_key]
if last_role not in _ALLOWED_LAST_MESSAGE_ROLES:
- raise ValueError(f'Invalid last message role: {last_role}')
+ raise InvalidLastChatMessageRoleError(last_role,
+ _ALLOWED_LAST_MESSAGE_ROLES)
+ last_message_role = None
for message in messages:
role_key, content_key = _get_key(message, _ALLOWED_ROLE_KEYS), _get_key(
message, _ALLOWED_CONTENT_KEYS)
if len(message.keys()) != 2:
- raise ValueError(
- f'Expected 2 keys in message, but found {len(message.keys())}')
+ raise IncorrectMessageKeyQuantityError(list(message.keys()))
if message[role_key] not in _ALLOWED_ROLES:
- raise ValueError(f'Invalid role: {message[role_key]}')
+ raise InvalidRoleError(message[role_key], _ALLOWED_ROLES)
if not isinstance(message[content_key], str):
- raise TypeError(
- f'Expected content to be a string, but found {type(message[content_key])}'
- )
+ raise InvalidContentTypeError(type(message[content_key]))
+ if last_message_role is not None and last_message_role == message[
+ role_key]:
+ raise ConsecutiveRepeatedChatRolesError(last_message_role)
+ last_message_role = message[role_key]
def _slice_chat_formatted_example(
@@ -182,8 +197,8 @@ def _slice_chat_formatted_example(
last_message = messages[-1]
if last_message['role'] != 'assistant':
- raise ValueError(
- f'last message must be from assistant. {last_message=}')
+ raise InvalidLastChatMessageRoleError(last_message['role'],
+ set(['assistant']))
def slice_out_last_turn(
messages_through_current_turn: List[Dict[str, str]],
@@ -291,31 +306,20 @@ def _tokenize_prompt_response_formatted_example(
response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)
if len(prompt_keys) != 1:
- raise KeyError(
- f'Unable to tokenize example because {len(prompt_keys)} of the allowed prompt keys ' +\
- f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_PROMPT_KEYS=}'
- )
+ raise TooManyKeysInExampleError(_ALLOWED_PROMPT_KEYS, prompt_keys)
if len(response_keys) != 1:
- raise KeyError(
- f'Unable to tokenize example because {len(response_keys)} of the allowed response keys ' +\
- f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_RESPONSE_KEYS=}'
- )
+ raise TooManyKeysInExampleError(_ALLOWED_RESPONSE_KEYS, response_keys)
prompt_key = prompt_keys.pop()
response_key = response_keys.pop()
prompt = example[prompt_key]
response = example[response_key]
-
if not isinstance(prompt, str):
- raise TypeError(
- f'Unable to tokenize example because {prompt_key} was not a string. {example=}'
- )
+ raise InvalidPromptTypeError(type(prompt))
if not isinstance(response, str):
- raise TypeError(
- f'Unable to tokenize example because {response_key} was not a string. {example=}'
- )
+ raise InvalidResponseTypeError(type(response))
# Note: We default to the tokenizer's add_bos_token and add_eos_token behavior here
# (which we do not do for chat-formatted examples). This is because chat examples specifically
@@ -360,7 +364,7 @@ def tokenize_formatted_example(
return _tokenize_prompt_response_formatted_example(
prompt_response_example, tokenizer)
else:
- raise ValueError(f'Unknown conversation type {example_format=}')
+ raise UnknownExampleTypeError(example)
def is_valid_ift_example(max_seq_len: int, target_prompts: str,
@@ -428,7 +432,7 @@ def _stream_remote_local_validate(remote: Optional[str], local: Optional[str],
contents = set(os.listdir(local))
if split is not None and split not in contents:
raise ValueError(
- f'local directory {local} does not contain split {split}')
+ f'Local directory {local} does not contain split {split}')
class StreamingFinetuningDataset(StreamingDataset):
@@ -636,9 +640,7 @@ def get_preprocessing_fn_from_dict(
def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:
if list(mapping.keys()) != ['prompt', 'response']:
- raise ValueError(
- f'Expected {mapping=} to have keys "prompt" and "response".'
- )
+ raise InvalidPromptResponseKeysError(mapping, example)
return {
'prompt': example[mapping['prompt']],
'response': example[mapping['response']]
@@ -697,9 +699,8 @@ def get_preprocessing_fn_from_str(
return preprocessing_fn
def build_from_hf(
- self, dataset_name: str, split: Optional[str], safe_load: bool,
- max_seq_len: int, preprocessing_fn: Optional[Callable[[dict[str, Any]],
- dict[str, str]]],
+ self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int,
+ preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]],
tokenizer: PreTrainedTokenizerBase, target_prompts: str,
target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str,
Any]
@@ -758,9 +759,8 @@ def build_from_hf(
local_dir_use_symlinks=False,
local_dir=local_dataset_dir)
if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
- raise FileNotFoundError(
- f'safe_load is set to True. No data files with safe extensions {SUPPORTED_EXTENSIONS} '
- + f'found for dataset {dataset_name}. ')
+ raise InvalidFileExtensionError(
+ dataset_name, SUPPORTED_EXTENSIONS)
# Set dataset_name to the downloaded location.
dataset_name = local_dataset_dir
@@ -774,9 +774,9 @@ def build_from_hf(
if not all(
Path(f).suffix in SUPPORTED_EXTENSIONS
for f in dataset_files):
- raise ValueError(
- f'Dataset at local path {dataset_name} contains invalid file types. '
- + f'Allowed file types are: {SUPPORTED_EXTENSIONS}')
+ raise InvalidFileExtensionError(dataset_name,
+ SUPPORTED_EXTENSIONS)
+
dataset = hf_datasets.load_dataset(dataset_name,
split=split,
**hf_kwargs)
@@ -853,9 +853,8 @@ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
prompt, response = inp['text'].split('### Response:')
prompt += '### Response:'
except Exception as e:
- raise ValueError(
- f"Unable to extract prompt/response from 'text'={inp['text']}"
- ) from e
+ raise UnableToProcessPromptResponseError(inp) from e
+
return {'prompt': prompt, 'response': response}
@@ -871,8 +870,7 @@ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
prompt = PROMPT_FORMAT.format(instruction=instruction)
response = inp['output']
except Exception as e:
- raise ValueError(
- f'Unable to extract prompt/response from {inp=}') from e
+ raise UnableToProcessPromptResponseError(inp) from e
return {'prompt': prompt, 'response': response}
@@ -898,6 +896,5 @@ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
response.startswith(transitions)):
response = ' ' + response
except Exception as e:
- raise ValueError(
- f'Unable to process prompt/response from {inp=}') from e
+ raise UnableToProcessPromptResponseError(inp) from e
return {'prompt': prompt, 'response': response}
diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index fba3ab2d3e..9f4af3099e 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -71,7 +71,6 @@ def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
'input_ids',
'labels',
'attention_mask',
- 'bidirectional_mask',
'sequence_id',
]
# Cut everything down to size
@@ -278,7 +277,6 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int):
'input_ids': pad_token_id,
'labels': -100,
'attention_mask': 0,
- 'bidirectional_mask': 0,
'sequence_id': -1,
}
keys = packed_examples[0].keys()
diff --git a/llmfoundry/metrics/__init__.py b/llmfoundry/metrics/__init__.py
index 116b6dd08c..6c71a3ea08 100644
--- a/llmfoundry/metrics/__init__.py
+++ b/llmfoundry/metrics/__init__.py
@@ -43,11 +43,6 @@
'code_eval_accuracy',
]
-DEFAULT_PREFIX_LM_METRICS = [
- 'language_cross_entropy',
- 'masked_accuracy',
-]
-
DEFAULT_ENC_DEC_METRICS = [
'language_cross_entropy',
'masked_accuracy',
@@ -66,6 +61,5 @@
'MaskedAccuracy',
'DEFAULT_CAUSAL_LM_TRAIN_METRICS',
'DEFAULT_CAUSAL_LM_EVAL_METRICS',
- 'DEFAULT_PREFIX_LM_METRICS',
'DEFAULT_ENC_DEC_METRICS',
]
diff --git a/llmfoundry/models/__init__.py b/llmfoundry/models/__init__.py
index 36234d3c14..ea144225c0 100644
--- a/llmfoundry/models/__init__.py
+++ b/llmfoundry/models/__init__.py
@@ -1,8 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
- ComposerHFT5)
+from llmfoundry.models.hf import ComposerHFCausalLM, ComposerHFT5
from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper,
FMAPIChatAPIEvalWrapper,
OpenAICausalLMEvalWrapper,
@@ -13,7 +12,6 @@
models.register('mpt_causal_lm', func=ComposerMPTCausalLM)
models.register('hf_causal_lm', func=ComposerHFCausalLM)
-models.register('hf_prefix_lm', func=ComposerHFPrefixLM)
models.register('hf_t5', func=ComposerHFT5)
models.register('openai_causal_lm', func=OpenAICausalLMEvalWrapper)
models.register('fmapi_causal_lm', func=FMAPICasualLMEvalWrapper)
@@ -22,7 +20,6 @@
__all__ = [
'ComposerHFCausalLM',
- 'ComposerHFPrefixLM',
'ComposerHFT5',
'MPTConfig',
'MPTPreTrainedModel',
diff --git a/llmfoundry/models/hf/__init__.py b/llmfoundry/models/hf/__init__.py
index d0ab65dfaf..3c35080d6e 100644
--- a/llmfoundry/models/hf/__init__.py
+++ b/llmfoundry/models/hf/__init__.py
@@ -5,12 +5,10 @@
from llmfoundry.models.hf.hf_fsdp import (prepare_hf_causal_lm_model_for_fsdp,
prepare_hf_enc_dec_model_for_fsdp,
prepare_hf_model_for_fsdp)
-from llmfoundry.models.hf.hf_prefix_lm import ComposerHFPrefixLM
from llmfoundry.models.hf.hf_t5 import ComposerHFT5
__all__ = [
'ComposerHFCausalLM',
- 'ComposerHFPrefixLM',
'ComposerHFT5',
'prepare_hf_causal_lm_model_for_fsdp',
'prepare_hf_enc_dec_model_for_fsdp',
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 7640a7b8a8..38ed7a7e70 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -17,11 +17,10 @@
from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
DEFAULT_CAUSAL_LM_TRAIN_METRICS)
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
from llmfoundry.utils.config_utils import pop_config
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
if TYPE_CHECKING:
from peft import PeftConfig
@@ -31,7 +30,7 @@
log = logging.getLogger(__name__)
-class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
+class ComposerHFCausalLM(HuggingFaceModelWithFSDP):
"""Configures a :class:`.HuggingFaceModel` around a Causal LM.
Args:
@@ -54,11 +53,8 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
cfg.use_auth_token (bool, optional): Whether to use the Hugging Face authentication token when
loading from Hugging Face Hub. Default: ``False``.
cfg.use_train_metrics (bool, optional): Whether to use training metrics. Default: ``True``.
- cfg.z_loss (float, optional): The z-loss coefficient. Default: ``0.0``.
cfg.load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``.
cfg.init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``.
- cfg.attention_patch_type (str, optional): Which attention patch to use for llama models. Default: ``None``.
- Deprecated. Will automatically use flash attention 2.
cfg.use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``.
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
"""
@@ -88,26 +84,15 @@ def __init__(self, om_model_config: DictConfig,
load_in_8bit = om_model_config.get('load_in_8bit', False)
# Set up config args for the model construction and base classes
- z_loss = om_model_config.get('z_loss', 0.0)
init_device = om_model_config.get('init_device', 'cpu')
# Resolve "mixed" init device to either "cpu" or "meta"
resolved_init_device = hf_get_init_device(init_device)
- attention_patch_type = om_model_config.get('attention_patch_type', None)
- if attention_patch_type is not None:
- warnings.warn(
- VersionedDeprecationWarning(
- 'attention_patch_type is deprecated and will automatically use flash attention 2. '
- +
- 'We recommend `use_flash_attention_2: true` for llama models.',
- remove_version='0.7.0'))
- use_flash_attention_2 = True
-
requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
- + 'Please `pip install llm-foundry[gpu-flash2]`.')
+ + 'Please `pip install llm-foundry[gpu]`.')
peft_config_dict = pop_config(om_model_config,
'peft_config',
@@ -269,7 +254,6 @@ def _autoset_attn_implementation_monkeypatch(
tokenizer=tokenizer,
metrics=train_metrics,
eval_metrics=eval_metrics,
- z_loss=z_loss,
init_device=init_device,
peft_config=peft_config,
)
diff --git a/llmfoundry/models/hf/hf_prefix_lm.py b/llmfoundry/models/hf/hf_prefix_lm.py
deleted file mode 100644
index 67060a02b8..0000000000
--- a/llmfoundry/models/hf/hf_prefix_lm.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Implements a Hugging Prefix LM wrapped inside a :class:`.ComposerModel`."""
-
-from __future__ import annotations
-
-from typing import Mapping, MutableMapping
-
-from composer.utils import dist
-from omegaconf import DictConfig
-from transformers import (AutoConfig, AutoModelForCausalLM,
- PreTrainedTokenizerBase)
-
-from llmfoundry.metrics import DEFAULT_PREFIX_LM_METRICS
-from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
- add_bidirectional_mask_if_missing,
- convert_hf_causal_lm_to_prefix_lm,
- init_empty_weights)
-
-__all__ = ['ComposerHFPrefixLM']
-
-
-class ComposerHFPrefixLM(HuggingFaceModelWithZLoss):
- """Configures a :class:`.HuggingFaceModel` around a Prefix LM.
-
- Note: HuggingFace does not natively support Prefix LM-style models. This function uses
- `transformers.AutoModelForCausalLM` to instantiate a Causal LM, then uses a conversion utility
- to turn the model into a Prefix LM. Currently, that conversion utility only supports the
- following HuggingFace Causal LM types:
- - `GPT2LMHeadModel`
- - `GPTNeoForCausalLM`
- - `GPTNeoXForCausalLM`
- - `GPTJForCausalLM`
- - `BloomForCausalLM`
- - `OPTForCausalLM`
-
- Args:
- cfg (DictConfig): An omegaconf dictionary used to configure the model:
- cfg.pretrained_model_name_or_path (str): The name of or local path to
- the HF model (e.g., `gpt2` to instantiate a GPT2LMHeadModel). The model
- will be converted to a Prefix LM during initialization.
- cfg.config_overrides (dict, optional): An optional dictionary of keyword
- arguments that override the default configuration associated with
- cfg.pretrained_model_name_or_path. Default: ``{}``.
- cfg.pretrained (bool): Whether to instantiate the model with pre-trained
- weights coming from cfg.pretrained_model_name_or_path. If ``True``,
- cfg.config_overrides must be compatible with the pre-trained weights.
- cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
- initialize the model on. Currently, `meta` is only supported when
- cfg.pretrained is ``False``. Default: ``'cpu'``.
- cfg.z_loss (float, optional): The coefficient of the z-loss. If >0.0, this
- the z-loss will be multiplied by this value before being added to the
- standard loss term. Default: ``0.0``.
- cfg.adapt_vocab_for_denoising (bool, optional): Whether to adapt the vocab
- of the model/tokenizer to include sentinel tokens that are used in denoising
- tasks like Span Corruption. If you intend to load from an existing Composer
- checkpoint that was trained on such a task, set this to ``True`` to ensure
- that the model vocab size matches your checkpoint's vocab size when loading
- the weights. Default: ``False``.
- tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
- """
-
- def __init__(self, om_model_config: DictConfig,
- tokenizer: PreTrainedTokenizerBase):
- from llmfoundry.utils.builders import build_metric
-
- config = AutoConfig.from_pretrained(
- om_model_config.pretrained_model_name_or_path,
- trust_remote_code=om_model_config.get('trust_remote_code', True),
- use_auth_token=om_model_config.get('use_auth_token', False),
- )
-
- # set config overrides
- for k, v in om_model_config.get('config_overrides', {}).items():
- if not hasattr(config, k):
- raise ValueError(
- f'config does not have attribute "{k}" to override ({k}: {v}).'
- )
-
- attr = getattr(config, k)
- if isinstance(attr, Mapping):
- extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
- if extra_keys:
- raise ValueError(
- f'Config dict override got unknown keys. ' +
- f'Extra keys: {extra_keys}. ' +
- f'Expected (a subset of) keys: {list(attr.keys())}.')
- getattr(config, k).update(v)
- else:
- setattr(config, k, v)
-
- # Set up the tokenizer (add tokens for denoising sentinels if needed)
- if om_model_config.get('adapt_vocab_for_denoising', False):
- adapt_tokenizer_for_denoising(tokenizer)
-
- init_device = om_model_config.get('init_device', 'cpu')
-
- # Get the device we want to initialize, and use the
- # resolved version to initialize the HF model
- resolved_init_device = hf_get_init_device(init_device)
-
- # We need to have all non-zero local ranks be not-pretrained
- # Rank 0 will still be pretrained, and distribute the weights appropriately
- if dist.get_local_rank() != 0 and init_device == 'mixed':
- om_model_config.pretrained = False
-
- if resolved_init_device == 'cpu':
- if om_model_config.pretrained:
- model = AutoModelForCausalLM.from_pretrained(
- om_model_config.pretrained_model_name_or_path,
- config=config)
- else:
- model = AutoModelForCausalLM.from_config(config)
- elif resolved_init_device == 'meta':
- if om_model_config.pretrained:
- raise ValueError(
- 'Setting cfg.pretrained=True is not supported when init_device="meta".'
- )
- with init_empty_weights(include_buffers=False):
- model = AutoModelForCausalLM.from_config(config)
- else:
- raise ValueError(
- f'init_device="{init_device}" must be either "cpu" or "meta".')
-
- # Convert the Causal LM into a Prefix LM via our custom wrapper
- model = convert_hf_causal_lm_to_prefix_lm(model)
-
- metrics = [
- build_metric(metric, {}) for metric in DEFAULT_PREFIX_LM_METRICS +
- om_model_config.get('additional_train_metrics', [])
- ]
-
- composer_model = super().__init__(model=model,
- shift_labels=True,
- tokenizer=tokenizer,
- metrics=metrics,
- z_loss=om_model_config.get(
- 'z_loss', 0.0),
- init_device=init_device)
-
- return composer_model
-
- def forward(self, batch: MutableMapping):
- # Add bidirectional_mask if it is missing and can be constructed
- add_bidirectional_mask_if_missing(batch)
- return super().forward(batch)
diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py
index 5956d49cc8..b9c1df64cf 100644
--- a/llmfoundry/models/hf/hf_t5.py
+++ b/llmfoundry/models/hf/hf_t5.py
@@ -14,16 +14,15 @@
from llmfoundry.metrics import DEFAULT_ENC_DEC_METRICS
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
- init_empty_weights)
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
+from llmfoundry.models.utils import init_empty_weights
from llmfoundry.utils.warnings import experimental_class
__all__ = ['ComposerHFT5']
@experimental_class('ComposerHFT5')
-class ComposerHFT5(HuggingFaceModelWithZLoss):
+class ComposerHFT5(HuggingFaceModelWithFSDP):
"""Configures a :class:`.HuggingFaceModel` around a T5.
Note: This function uses `transformers.T5ForConditionalGeneration`. Future releases
@@ -42,15 +41,6 @@ class ComposerHFT5(HuggingFaceModelWithZLoss):
cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
initialize the model on. Currently, `meta` is only supported when
cfg.pretrained is ``False``. Default: ``'cpu'``.
- cfg.z_loss (float, optional): The coefficient of the z-loss. If >0.0, this
- the z-loss will be multiplied by this value before being added to the
- standard loss term. Default: ``0.0``.
- cfg.adapt_vocab_for_denoising (bool, optional): Whether to adapt the vocab
- of the model/tokenizer to include sentinel tokens that are used in denoising
- tasks like Span Corruption. If you intend to load from an existing Composer
- checkpoint that was trained on such a task, set this to ``True`` to ensure
- that the model vocab size matches your checkpoint's vocab size when loading
- the weights. Default: ``False``.
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
"""
@@ -87,10 +77,6 @@ def __init__(self, om_model_config: DictConfig,
raise ValueError(f'Model type "hf_t5" currently only supports T5 models ' +\
f'using configs where `is_encoder_decoder` is ``True``.')
- # Set up the tokenizer (add tokens for denoising sentinels if needed)
- if om_model_config.get('adapt_vocab_for_denoising', False):
- adapt_tokenizer_for_denoising(tokenizer)
-
init_device = om_model_config.get('init_device', 'cpu')
# Get the device we want to initialize, and use the
@@ -128,8 +114,6 @@ def __init__(self, om_model_config: DictConfig,
composer_model = super().__init__(model=model,
tokenizer=tokenizer,
metrics=metrics,
- z_loss=om_model_config.get(
- 'z_loss', 0.0),
init_device=init_device)
return composer_model
diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py
index e9f4c89796..2ba88d390c 100644
--- a/llmfoundry/models/hf/model_wrapper.py
+++ b/llmfoundry/models/hf/model_wrapper.py
@@ -5,11 +5,9 @@
from __future__ import annotations
-import warnings
from collections import UserDict
from typing import TYPE_CHECKING, List, Mapping, Optional
-import torch
import transformers
from composer.models.huggingface import HuggingFaceModel
from torchmetrics import Metric
@@ -17,7 +15,6 @@
from transformers.utils.generic import ModelOutput
from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
if TYPE_CHECKING:
from peft import PeftConfig
@@ -26,19 +23,9 @@
_HF_IGNORE_INDEX = -100
-class HuggingFaceModelWithZLoss(HuggingFaceModel):
+class HuggingFaceModelWithFSDP(HuggingFaceModel):
"""Wrapper around HuggingFaceModel.
- This adds z-loss, which is used in some training contexts,
- and is a convenient way to patch features that are generically
- useful for HF models.
- See use of z_loss in PaLM: https://arxiv.org/abs/2204.02311v3, Section 5.
- Also, from https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666:
- Two uses of z_loss are:
- - To keep the logits from drifting too far from zero, which can cause
- unacceptable roundoff errors in bfloat16.
- - To encourage the logits to be normalized log-probabilities.
-
Handles preparation for FSDP wrapping.
"""
@@ -47,7 +34,6 @@ def __init__(self,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
metrics: Optional[List[Metric]] = None,
eval_metrics: Optional[List[Metric]] = None,
- z_loss: float = 0.0,
shift_labels: bool = False,
init_device: Optional[str] = None,
peft_config: Optional['PeftConfig'] = None):
@@ -61,9 +47,6 @@ def __init__(self,
peft_config=peft_config,
should_save_peft_only=True,
)
- self.z_loss = float(z_loss)
- if self.z_loss < 0.0:
- raise ValueError(f'z_loss(={z_loss}) cannot be negative.')
# Note: We need to add the FSDP related attributes to the model AFTER the super init,
# so that the (possible) embedding resizing doesn't destroy them
@@ -88,27 +71,6 @@ def forward(self, batch: Mapping):
def loss(self, outputs: ModelOutput, batch: Mapping):
if self.config.use_return_dict:
- loss, logits = outputs['loss'], outputs['logits']
- else:
- # loss is at index 0 in the output tuple, logits are at index 1
- loss, logits = outputs[:2]
- if self.z_loss == 0.0:
- return loss
-
- warnings.warn(
- VersionedDeprecationWarning('z-loss is deprecated.',
- remove_version='0.7.0'))
-
- # Add a z_loss to the standard loss
- logits_flat = logits.view(-1, logits.size(-1))
- labels_flat = batch['labels'].view(-1)
- log_z = torch.logsumexp(logits_flat[labels_flat != _HF_IGNORE_INDEX],
- dim=1)
- log_z2 = log_z**2
- z_loss = log_z2.mean() * self.z_loss
- if self.config.use_return_dict:
- outputs['loss'] += z_loss
return outputs['loss']
- else:
- outputs[0] += z_loss
- return outputs[0]
+ # loss is at index 0 in the output tuple, logits are at index 1
+ return outputs[:2]
diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py
index 05350b059b..df4216b81c 100644
--- a/llmfoundry/models/layers/__init__.py
+++ b/llmfoundry/models/layers/__init__.py
@@ -4,7 +4,7 @@
from llmfoundry.models.layers.attention import (
ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
- flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn)
+ flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
@@ -14,7 +14,6 @@
__all__ = [
'scaled_multihead_dot_product_attention',
'flash_attn_fn',
- 'triton_flash_attn_fn',
'MultiheadAttention',
'MultiQueryAttention',
'GroupedQueryAttention',
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 79780bccee..1deca69eb2 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -50,7 +50,7 @@ def check_alibi_support(attention_impl: str) -> bool:
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
original_is_causal: bool) -> bool:
# disable causal when it is not needed
- # necessary for flash & triton for generation with kv_cache
+ # necessary for flash for generation with kv_cache
if original_is_causal and num_query_tokens != num_key_tokens:
if num_query_tokens != 1:
raise NotImplementedError(
@@ -100,7 +100,7 @@ def scaled_multihead_dot_product_attention(
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
if past_key_value is not None:
- # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
+ # attn_impl: flash attn uses kernels which expect input shape [b, s, h, d_head].
# kv_cache is therefore stored using that shape.
# attn_impl: torch stores the kv_cache in the ordering which is most advantageous
# for its attn computation ie
@@ -341,123 +341,13 @@ def flash_attn_fn(
return output, None, past_key_value
-def triton_flash_attn_fn(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- n_heads: int,
- kv_n_heads: int,
- past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
- softmax_scale: Optional[float] = None,
- attn_bias: Optional[torch.Tensor] = None,
- key_padding_mask: Optional[torch.Tensor] = None,
- is_causal: bool = False,
- dropout_p: float = 0.0,
- training: bool = False,
- needs_weights: bool = False,
-) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
- torch.Tensor]]]:
- try:
- from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
- except:
- _installed = False
- if version.parse(torch.__version__) < version.parse('2.0.0'):
- _installed = True
- # if torch1.13.1 revert to using triton flash attn from HazyResearch
- # with flash-attn==1.0.9 and triton==2.0.0.dev20221202
- try:
- from flash_attn.flash_attn_triton import flash_attn_func
- except:
- _installed = False
- if not _installed:
- # installing triton-pre-mlir works for both torch1.13.1 and torch2.0+
- # default recommendation is to install this variant
- raise RuntimeError(
- 'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU '
- +
- 'and `pip install .[gpu]` if installing from llm-foundry source or '
- +
- '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` '
- +
- 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). '
- +
- 'Note: (1) requires you have CMake and PyTorch already installed.'
- )
-
- check_valid_inputs(query, key, value)
-
- if past_key_value is not None:
- if len(past_key_value) != 0:
- key = torch.cat([past_key_value[0], key], dim=1)
- value = torch.cat([past_key_value[1], value], dim=1)
-
- past_key_value = (key, value)
-
- if attn_bias is not None:
- # clamp to 0 necessary for torch 2.0 compile()
- _s_q = max(0, attn_bias.size(2) - query.size(1))
- _s_k = max(0, attn_bias.size(3) - key.size(1))
- attn_bias = attn_bias[:, :, _s_q:, _s_k:]
-
- if dropout_p:
- raise NotImplementedError(
- f'Dropout not implemented for attn_impl: triton.')
- dropout_p = dropout_p if training else 0.0
-
- if needs_weights:
- raise NotImplementedError(
- f'attn_impl: triton cannot return attn weights.')
-
- if key_padding_mask is not None:
- warnings.warn(
- 'Propagating key_padding_mask to the attention module ' +\
- 'and applying it within the attention module can cause ' +\
- 'unnecessary computation/memory usage. Consider integrating ' +\
- 'into attn_bias once and passing that to each attention ' +\
- 'module instead.'
- )
- b_size, s_k = key_padding_mask.shape[:2]
-
- if attn_bias is None:
- attn_bias = query.new_zeros(b_size, 1, 1, s_k)
-
- attn_bias = attn_bias.masked_fill(
- ~key_padding_mask.view((b_size, 1, 1, s_k)),
- torch.finfo(query.dtype).min)
-
- query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
- key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
- value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
-
- # multi-query case
- if kv_n_heads == 1:
- # necessary to repeat instead of expand tensor because
- # output contains NaN in edge cases such as with head dimension = 8
- key = key.repeat(1, 1, n_heads, 1)
- value = value.repeat(1, 1, n_heads, 1)
- # grouped query case
- elif kv_n_heads < n_heads:
- # Each query belong to a group of kv heads of group size n_heads // kv_n_heads
- # We repeat each kv head by the group size number to use the underlying MHA kernels
- key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
- value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
-
- reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
- attn_output = flash_attn_func( # type: ignore
- query, key, value, attn_bias, reset_is_causal, softmax_scale)
-
- output = attn_output.view(*attn_output.shape[:2], -1) # type: ignore
-
- return output, None, past_key_value
-
-
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
and Multi-query attention (MQA).
This allows the user to set a variable of number of kv_n_heads, rather than
- just n_heads or 1, as in MHA and MQA. Using torch or triton attention
+ just n_heads or 1, as in MHA and MQA. Using torch attention
implementation enables user to also use additive bias.
"""
@@ -466,7 +356,7 @@ def __init__(
d_model: int,
n_heads: int,
kv_n_heads: int,
- attn_impl: str = 'triton',
+ attn_impl: str = 'flash',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
@@ -538,8 +428,6 @@ def __init__(
if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
- elif self.attn_impl == 'triton':
- self.attn_fn = triton_flash_attn_fn
elif self.attn_impl == 'torch':
self.attn_fn = scaled_multihead_dot_product_attention
else:
@@ -679,15 +567,14 @@ def forward(
class MultiheadAttention(GroupedQueryAttention):
"""Multi-head self attention.
- Using torch or triton attention implementation enables user to also use
- additive bias.
+ Using torch attention implementation enables user to also use additive bias.
"""
def __init__(
self,
d_model: int,
n_heads: int,
- attn_impl: str = 'triton',
+ attn_impl: str = 'flash',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
@@ -720,15 +607,14 @@ def __init__(
class MultiQueryAttention(GroupedQueryAttention):
"""Multi-Query self attention.
- Using torch or triton attention implementation enables user to also use
- additive bias.
+ Using torch attention implementation enables user to also use additive bias.
"""
def __init__(
self,
d_model: int,
n_heads: int,
- attn_impl: str = 'triton',
+ attn_impl: str = 'flash',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
@@ -759,17 +645,16 @@ def __init__(
def attn_bias_shape(
- attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
- prefix_lm: bool, causal: bool,
+ attn_impl: str, n_heads: int, seq_len: int, alibi: bool, causal: bool,
use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
if attn_impl == 'flash':
return None
- elif attn_impl in ['torch', 'triton']:
+ elif attn_impl == 'torch':
if alibi:
- if (prefix_lm or not causal) or use_sequence_id:
+ if (not causal) or use_sequence_id:
return (1, n_heads, seq_len, seq_len)
return (1, n_heads, 1, seq_len)
- elif prefix_lm or use_sequence_id:
+ elif use_sequence_id:
return (1, 1, seq_len, seq_len)
return None
else:
@@ -787,7 +672,7 @@ def build_attn_bias(
) -> Optional[torch.Tensor]:
if attn_impl == 'flash':
return None
- elif attn_impl in ['torch', 'triton']:
+ elif attn_impl == 'torch':
if alibi:
# in place add alibi to attn bias
device, dtype = attn_bias.device, attn_bias.dtype
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 4ac43a8bac..855df7903f 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -20,12 +20,11 @@
attn_config_defaults: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
- 'attn_impl': 'triton',
+ 'attn_impl': 'flash',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
- 'prefix_lm': False,
'attn_uses_sequence_id': False,
'sliding_window_size': -1,
'alibi': False,
@@ -79,9 +78,9 @@ def __init__(
# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
- 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id',
- 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl',
- 'rope_dail_config', 'rope_hf_config'
+ 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max',
+ 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config',
+ 'rope_hf_config'
}
attn_config_subset_for_attn_class = {
k: v
diff --git a/llmfoundry/models/layers/flash_attn_triton.py b/llmfoundry/models/layers/flash_attn_triton.py
deleted file mode 100644
index 9276d0f917..0000000000
--- a/llmfoundry/models/layers/flash_attn_triton.py
+++ /dev/null
@@ -1,835 +0,0 @@
-"""
-Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
-update imports to use 'triton_pre_mlir'
-
-*Experimental* implementation of FlashAttention in Triton.
-Tested with triton==2.0.0.dev20221202.
-Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
-other than 64:
-https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
-We'll update this implementation with the new Triton backend once this is fixed.
-
-We use the FlashAttention implementation from Phil Tillet a starting point.
-https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
-
-Changes:
-- Implement both causal and non-causal attention.
-- Implement both self-attention and cross-attention.
-- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
-- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
-- Support attention bias.
-- Speed up the forward pass a bit, and only store the LSE instead of m and l.
-- Make the backward for d=128 much faster by reducing register spilling.
-- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
-small batch size * nheads.
-
-Caution:
-- This is an *experimental* implementation. The forward pass should be quite robust but
-I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
-- This implementation has only been tested on A100.
-- If you plan to use headdim other than 64 and 128, you should test for race conditions
-(due to the Triton compiler), as done in tests/test_flash_attn.py
-"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
-for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
-that there are none left for other head dimensions.
-
-Differences between this Triton version and the CUDA version:
-- Triton version doesn't support dropout.
-- Triton forward is generally faster than CUDA forward, while Triton backward is
-generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
-than CUDA forward + backward.
-- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
-- Triton version supports attention bias, while CUDA version doesn't.
-"""
-
-import math
-
-import torch
-
-import triton_pre_mlir as triton
-import triton_pre_mlir.language as tl
-
-
-# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
-# @triton.autotune(
-# configs=[
-# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
-# # This config has a race condition when EVEN_M == False, disabling it for now.
-# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
-# ],
-# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
-# )
-@triton.heuristics(
- {
- "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
- "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
- "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
- }
-)
-@triton.jit
-def _fwd_kernel(
- Q, K, V, Bias, Out,
- Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
- softmax_scale,
- stride_qb, stride_qh, stride_qm,
- stride_kb, stride_kh, stride_kn,
- stride_vb, stride_vh, stride_vn,
- stride_bb, stride_bh, stride_bm,
- stride_ob, stride_oh, stride_om,
- nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
- CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
- BIAS_TYPE: tl.constexpr,
- IS_CAUSAL: tl.constexpr,
- BLOCK_HEADDIM: tl.constexpr,
- EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-):
- start_m = tl.program_id(0)
- off_hb = tl.program_id(1)
- off_b = off_hb // nheads
- off_h = off_hb % nheads
- # off_b = tl.program_id(1)
- # off_h = tl.program_id(2)
- # off_hb = off_b * nheads + off_h
- # initialize offsets
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = tl.arange(0, BLOCK_N)
- offs_d = tl.arange(0, BLOCK_HEADDIM)
- # Initialize pointers to Q, K, V
- # Adding parenthesis around indexing might use int32 math instead of int64 math?
- # https://github.com/openai/triton/issues/741
- # I'm seeing a tiny bit of difference (5-7us)
- q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
- k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
- v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
- if BIAS_TYPE == 'vector':
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
- elif BIAS_TYPE == 'matrix':
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
- # initialize pointer to m and l
- t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
- lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
- acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
- # load q: it will stay in SRAM throughout
- # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
- # tl.load(q_ptrs), we get the wrong output!
- if EVEN_M & EVEN_N:
- if EVEN_HEADDIM:
- q = tl.load(q_ptrs)
- else:
- q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
- else:
- if EVEN_HEADDIM:
- q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
- else:
- q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
- other=0.0)
- # loop over k, v and update accumulator
- end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
- for start_n in range(0, end_n, BLOCK_N):
- start_n = tl.multiple_of(start_n, BLOCK_N)
- # -- compute qk ----
- if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
- if EVEN_HEADDIM:
- k = tl.load(k_ptrs + start_n * stride_kn)
- else:
- k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
- else:
- if EVEN_HEADDIM:
- k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
- other=0.0)
- else:
- k = tl.load(k_ptrs + start_n * stride_kn,
- mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
- other=0.0)
- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
- qk += tl.dot(q, k, trans_b=True)
- # Trying to combine the two masks seem to make the result wrong
- if not EVEN_N: # Need to mask out otherwise the softmax is wrong
- qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
- if IS_CAUSAL:
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
- if BIAS_TYPE != 'none':
- if BIAS_TYPE == 'vector':
- if EVEN_N:
- bias = tl.load(b_ptrs + start_n).to(tl.float32)
- else:
- bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
- bias = bias[None, :]
- elif BIAS_TYPE == 'matrix':
- if EVEN_M & EVEN_N:
- bias = tl.load(b_ptrs + start_n).to(tl.float32)
- else:
- bias = tl.load(b_ptrs + start_n,
- mask=(offs_m[:, None] < seqlen_q)
- & ((start_n + offs_n)[None, :] < seqlen_k),
- other=0.0).to(tl.float32)
- # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
- # can then fuse the mult and add into an fma instruction. But if we have bias we need to
- # to multiply with softmax_scale here.
- qk = qk * softmax_scale + bias
- m_ij = tl.maximum(tl.max(qk, 1), lse_i)
- p = tl.exp(qk - m_ij[:, None])
- else:
- m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
- p = tl.exp(qk * softmax_scale - m_ij[:, None])
- l_ij = tl.sum(p, 1)
-
- # scale acc_o
- acc_o_scale = tl.exp(m_i - m_ij)
-
- # # -- update output accumulator --
- # BUG: have to store and immediately load
- tl.store(t_ptrs, acc_o_scale)
- acc_o_scale = tl.load(t_ptrs)
- acc_o = acc_o * acc_o_scale[:, None]
- # update acc_o
- if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
- if EVEN_HEADDIM:
- v = tl.load(v_ptrs + start_n * stride_vn)
- else:
- v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
- else:
- if EVEN_HEADDIM:
- v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
- other=0.0)
- else:
- v = tl.load(v_ptrs + start_n * stride_vn,
- mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
- other=0.0)
- p = p.to(v.dtype)
- acc_o += tl.dot(p, v)
-
- # -- update statistics
- m_i = m_ij
- l_i_new = tl.exp(lse_i - m_ij) + l_ij
- lse_i = m_ij + tl.log(l_i_new)
-
- o_scale = tl.exp(m_i - lse_i)
- # BUG: have to store and immediately load
- tl.store(t_ptrs, o_scale)
- o_scale = tl.load(t_ptrs)
- acc_o = acc_o * o_scale[:, None]
- # rematerialize offsets to save registers
- start_m = tl.program_id(0)
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- # write back l and m
- lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
- tl.store(lse_ptrs, lse_i)
- # initialize pointers to output
- offs_d = tl.arange(0, BLOCK_HEADDIM)
- out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
- if EVEN_M:
- if EVEN_HEADDIM:
- tl.store(out_ptrs, acc_o)
- else:
- tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
- else:
- if EVEN_HEADDIM:
- tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
- else:
- tl.store(out_ptrs, acc_o,
- mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
-
-
-@triton.jit
-def _bwd_preprocess_do_o_dot(
- Out, DO, Delta,
- stride_ob, stride_oh, stride_om,
- stride_dob, stride_doh, stride_dom,
- nheads, seqlen_q, seqlen_q_rounded, headdim,
- BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
-):
- start_m = tl.program_id(0)
- off_hb = tl.program_id(1)
- off_b = off_hb // nheads
- off_h = off_hb % nheads
- # initialize offsets
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_d = tl.arange(0, BLOCK_HEADDIM)
- # load
- o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
- mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
- do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
- mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
- delta = tl.sum(o * do, axis=1)
- # write-back
- tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
-
-
-@triton.jit
-def _bwd_store_dk_dv(
- dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
- EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
-):
- # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
- # if we just call tl.store(dv_ptrs), there's a race condition
- if EVEN_N & EVEN_M:
- if EVEN_HEADDIM:
- tl.store(dv_ptrs, dv)
- tl.store(dk_ptrs, dk)
- else:
- tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
- tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
- else:
- if EVEN_HEADDIM:
- tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
- tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
- else:
- tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
- tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
-
-
-@triton.jit
-def _bwd_kernel_one_col_block(
- start_n,
- Q, K, V, Bias,
- DO, DQ, DK, DV,
- LSE, D,
- softmax_scale,
- stride_qm, stride_kn, stride_vn, stride_bm,
- stride_dom, stride_dqm, stride_dkn, stride_dvn,
- seqlen_q, seqlen_k, headdim,
- ATOMIC_ADD: tl.constexpr,
- BIAS_TYPE: tl.constexpr,
- IS_CAUSAL: tl.constexpr,
- BLOCK_HEADDIM: tl.constexpr,
- EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-):
- # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
- begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
- # initialize row/col offsets
- offs_qm = begin_m + tl.arange(0, BLOCK_M)
- offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_m = tl.arange(0, BLOCK_M)
- offs_d = tl.arange(0, BLOCK_HEADDIM)
- # initialize pointers to value-like data
- q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
- k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
- v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
- do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
- dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
- if BIAS_TYPE == 'vector':
- b_ptrs = Bias + offs_n
- elif BIAS_TYPE == 'matrix':
- b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
- # initialize dv and dk
- dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
- dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
- # There seems to be some problem with Triton pipelining that makes results wrong for
- # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
- # may have zero step, and pipelining with the bias matrix could screw it up.
- # So we just exit early.
- if begin_m >= seqlen_q:
- dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
- dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
- _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
- EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
- return
- # k and v stay in SRAM throughout
- # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
- # if we just call tl.load(k_ptrs), we get the wrong output!
- if EVEN_N & EVEN_M:
- if EVEN_HEADDIM:
- k = tl.load(k_ptrs)
- v = tl.load(v_ptrs)
- else:
- k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
- v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
- else:
- if EVEN_HEADDIM:
- k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
- v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
- else:
- k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
- other=0.0)
- v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
- other=0.0)
- # loop over rows
- num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
- for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
- start_m = tl.multiple_of(start_m, BLOCK_M)
- offs_m_curr = start_m + offs_m
- # load q, k, v, do on-chip
- # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
- if EVEN_M & EVEN_HEADDIM:
- q = tl.load(q_ptrs)
- else:
- if EVEN_HEADDIM:
- q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
- else:
- q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
- & (offs_d[None, :] < headdim), other=0.0)
- # recompute p = softmax(qk, dim=-1).T
- qk = tl.dot(q, k, trans_b=True)
- # Trying to combine the two masks seem to make the result wrong
- if not EVEN_N: # Need to mask out otherwise the softmax is wrong
- qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
- if IS_CAUSAL:
- qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
- if BIAS_TYPE != 'none':
- tl.debug_barrier() # Race condition otherwise
- if BIAS_TYPE == 'vector':
- if EVEN_N:
- bias = tl.load(b_ptrs).to(tl.float32)
- else:
- bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
- bias = bias[None, :]
- elif BIAS_TYPE == 'matrix':
- if EVEN_M & EVEN_N:
- bias = tl.load(b_ptrs).to(tl.float32)
- else:
- bias = tl.load(b_ptrs,
- mask=(offs_m_curr[:, None] < seqlen_q)
- & (offs_n[None, :] < seqlen_k),
- other=0.0).to(tl.float32)
- qk = qk * softmax_scale + bias
- # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
- # Also wrong for headdim=64.
- if not (EVEN_M & EVEN_HEADDIM):
- tl.debug_barrier()
- lse_i = tl.load(LSE + offs_m_curr)
- if BIAS_TYPE == 'none':
- p = tl.exp(qk * softmax_scale - lse_i[:, None])
- else:
- p = tl.exp(qk - lse_i[:, None])
- # compute dv
- # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
- # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
- # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
- # the output is correct.
- if EVEN_M & EVEN_HEADDIM:
- do = tl.load(do_ptrs)
- else:
- # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
- do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
- & (offs_d[None, :] < headdim), other=0.0)
- # if EVEN_M:
- # if EVEN_HEADDIM:
- # do = tl.load(do_ptrs)
- # else:
- # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
- # else:
- # if EVEN_HEADDIM:
- # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
- # else:
- # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
- # & (offs_d[None, :] < headdim), other=0.0)
- dv += tl.dot(p.to(do.dtype), do, trans_a=True)
- # compute dp = dot(v, do)
- # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
- # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
- # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
- if not (EVEN_M & EVEN_HEADDIM):
- tl.debug_barrier()
- dp = tl.dot(do, v, trans_b=True)
- # There's a race condition for headdim=48
- if not EVEN_HEADDIM:
- tl.debug_barrier()
- # compute ds = p * (dp - delta[:, None])
- # Putting the subtraction after the dp matmul (instead of before) is slightly faster
- Di = tl.load(D + offs_m_curr)
- # Converting ds to q.dtype here reduces register pressure and makes it much faster
- # for BLOCK_HEADDIM=128
- ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
- # compute dk = dot(ds.T, q)
- dk += tl.dot(ds, q, trans_a=True)
- # compute dq
- if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix'
- tl.debug_barrier()
- if not ATOMIC_ADD:
- if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
- dq = tl.load(dq_ptrs, eviction_policy="evict_last")
- dq += tl.dot(ds, k)
- tl.store(dq_ptrs, dq, eviction_policy="evict_last")
- else:
- if EVEN_HEADDIM:
- dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
- eviction_policy="evict_last")
- dq += tl.dot(ds, k)
- tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
- eviction_policy="evict_last")
- else:
- dq = tl.load(dq_ptrs,
- mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
- other=0.0, eviction_policy="evict_last")
- dq += tl.dot(ds, k)
- tl.store(dq_ptrs, dq,
- mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
- eviction_policy="evict_last")
- else: # If we're parallelizing across the seqlen_k dimension
- dq = tl.dot(ds, k)
- if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
- tl.atomic_add(dq_ptrs, dq)
- else:
- if EVEN_HEADDIM:
- tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
- else:
- tl.atomic_add(dq_ptrs, dq,
- mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
- # increment pointers
- dq_ptrs += BLOCK_M * stride_dqm
- q_ptrs += BLOCK_M * stride_qm
- do_ptrs += BLOCK_M * stride_dom
- if BIAS_TYPE == 'matrix':
- b_ptrs += BLOCK_M * stride_bm
- # write-back
- dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
- dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
- _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
- EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
-
-
-def init_to_zero(name):
- return lambda nargs: nargs[name].zero_()
-
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
- # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
- # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
- # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
- # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
- # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
- # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
- ],
- key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
-)
-@triton.heuristics(
- {
- "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
- "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
- "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
- }
-)
-@triton.jit
-def _bwd_kernel(
- Q, K, V, Bias,
- DO, DQ, DK, DV,
- LSE, D,
- softmax_scale,
- stride_qb, stride_qh, stride_qm,
- stride_kb, stride_kh, stride_kn,
- stride_vb, stride_vh, stride_vn,
- stride_bb, stride_bh, stride_bm,
- stride_dob, stride_doh, stride_dom,
- stride_dqb, stride_dqh, stride_dqm,
- stride_dkb, stride_dkh, stride_dkn,
- stride_dvb, stride_dvh, stride_dvn,
- nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
- CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
- BIAS_TYPE: tl.constexpr,
- IS_CAUSAL: tl.constexpr,
- BLOCK_HEADDIM: tl.constexpr,
- SEQUENCE_PARALLEL: tl.constexpr,
- EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-):
- off_hb = tl.program_id(1)
- off_b = off_hb // nheads
- off_h = off_hb % nheads
- # offset pointers for batch/head
- Q += off_b * stride_qb + off_h * stride_qh
- K += off_b * stride_kb + off_h * stride_kh
- V += off_b * stride_vb + off_h * stride_vh
- DO += off_b * stride_dob + off_h * stride_doh
- DQ += off_b * stride_dqb + off_h * stride_dqh
- DK += off_b * stride_dkb + off_h * stride_dkh
- DV += off_b * stride_dvb + off_h * stride_dvh
- if BIAS_TYPE != 'none':
- Bias += off_b * stride_bb + off_h * stride_bh
- # pointer to row-wise quantities in value-like data
- D += off_hb * seqlen_q_rounded
- LSE += off_hb * seqlen_q_rounded
- if not SEQUENCE_PARALLEL:
- num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
- for start_n in range(0, num_block_n):
- _bwd_kernel_one_col_block(
- start_n,
- Q, K, V, Bias,
- DO, DQ, DK, DV,
- LSE, D,
- softmax_scale,
- stride_qm, stride_kn, stride_vn, stride_bm,
- stride_dom, stride_dqm, stride_dkn, stride_dvn,
- seqlen_q, seqlen_k, headdim,
- ATOMIC_ADD=False,
- BIAS_TYPE=BIAS_TYPE,
- IS_CAUSAL=IS_CAUSAL,
- BLOCK_HEADDIM=BLOCK_HEADDIM,
- EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
- )
- else:
- start_n = tl.program_id(0)
- _bwd_kernel_one_col_block(
- start_n,
- Q, K, V, Bias,
- DO, DQ, DK, DV,
- LSE, D,
- softmax_scale,
- stride_qm, stride_kn, stride_vn, stride_bm,
- stride_dom, stride_dqm, stride_dkn, stride_dvn,
- seqlen_q, seqlen_k, headdim,
- ATOMIC_ADD=True,
- BIAS_TYPE=BIAS_TYPE,
- IS_CAUSAL=IS_CAUSAL,
- BLOCK_HEADDIM=BLOCK_HEADDIM,
- EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
- )
-
-
-def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
- # shape constraints
- batch, seqlen_q, nheads, d = q.shape
- _, seqlen_k, _, _ = k.shape
- assert k.shape == (batch, seqlen_k, nheads, d)
- assert v.shape == (batch, seqlen_k, nheads, d)
- assert d <= 128, 'FlashAttention only support head dimensions up to 128'
- assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
- assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
- assert q.is_cuda and k.is_cuda and v.is_cuda
- softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
-
- has_bias = bias is not None
- bias_type = 'none'
- if has_bias:
- assert bias.dtype in [q.dtype, torch.float]
- assert bias.is_cuda
- assert bias.dim() == 4
- if bias.stride(-1) != 1:
- bias = bias.contiguous()
- if bias.shape[2:] == (1, seqlen_k):
- bias_type = 'vector'
- elif bias.shape[2:] == (seqlen_q, seqlen_k):
- bias_type = 'matrix'
- else:
- raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
- ' or (seqlen_q, seqlen_k)')
- bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
- seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
- lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
- tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
- o = torch.empty_like(q)
-
- BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
- BLOCK = 128
- num_warps = 4 if d <= 64 else 8
- grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
- _fwd_kernel[grid](
- q, k, v, bias, o,
- lse, tmp,
- softmax_scale,
- q.stride(0), q.stride(2), q.stride(1),
- k.stride(0), k.stride(2), k.stride(1),
- v.stride(0), v.stride(2), v.stride(1),
- *bias_strides,
- o.stride(0), o.stride(2), o.stride(1),
- nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
- seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
- # Can't use kwargs here because triton autotune expects key to be args, not kwargs
- # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
- bias_type, causal, BLOCK_HEADDIM,
- BLOCK_M=BLOCK, BLOCK_N=BLOCK,
- num_warps=num_warps,
- num_stages=1,
- )
- return o, lse, softmax_scale # softmax_scale could have been updated
-
-
-def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
- # Make sure that the last dimension is contiguous
- if do.stride(-1) != 1:
- do = do.contiguous()
- batch, seqlen_q, nheads, d = q.shape
- _, seqlen_k, _, _ = k.shape
- # assert d in {16, 32, 64, 128}
- assert d <= 128
- seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
- assert lse.shape == (batch, nheads, seqlen_q_rounded)
- assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
- assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
- softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
- # dq_accum = torch.zeros_like(q, dtype=torch.float32)
- dq_accum = torch.empty_like(q, dtype=torch.float32)
- delta = torch.empty_like(lse)
- # delta = torch.zeros_like(lse)
-
- BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
- grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
- _bwd_preprocess_do_o_dot[grid](
- o, do, delta,
- o.stride(0), o.stride(2), o.stride(1),
- do.stride(0), do.stride(2), do.stride(1),
- nheads, seqlen_q, seqlen_q_rounded, d,
- BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
- )
-
- has_bias = bias is not None
- bias_type = 'none'
- if has_bias:
- assert bias.dtype in [q.dtype, torch.float]
- assert bias.is_cuda
- assert bias.dim() == 4
- assert bias.stride(-1) == 1
- if bias.shape[2:] == (1, seqlen_k):
- bias_type = 'vector'
- elif bias.shape[2:] == (seqlen_q, seqlen_k):
- bias_type = 'matrix'
- else:
- raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
- ' or (seqlen_q, seqlen_k)')
- bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
- # BLOCK_M = 128
- # BLOCK_N = 64
- # num_warps = 4
- grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
- batch * nheads)
- _bwd_kernel[grid](
- q, k, v, bias,
- do, dq_accum, dk, dv,
- lse, delta,
- softmax_scale,
- q.stride(0), q.stride(2), q.stride(1),
- k.stride(0), k.stride(2), k.stride(1),
- v.stride(0), v.stride(2), v.stride(1),
- *bias_strides,
- do.stride(0), do.stride(2), do.stride(1),
- dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
- dk.stride(0), dk.stride(2), dk.stride(1),
- dv.stride(0), dv.stride(2), dv.stride(1),
- nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
- seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
- # Can't use kwargs here because triton autotune expects key to be args, not kwargs
- # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
- bias_type, causal, BLOCK_HEADDIM,
- # SEQUENCE_PARALLEL=False,
- # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
- # num_warps=num_warps,
- # num_stages=1,
- )
- dq.copy_(dq_accum)
-
-
-class FlashAttnQKVPackedFunc(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
- """
- qkv: (batch, seqlen, 3, nheads, headdim)
- bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
- ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
- """
- # Make sure that the last dimension is contiguous
- if qkv.stride(-1) != 1:
- qkv = qkv.contiguous()
- o, lse, ctx.softmax_scale = _flash_attn_forward(
- qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal,
- softmax_scale=softmax_scale
- )
- ctx.save_for_backward(qkv, o, lse, bias)
- ctx.causal = causal
- return o
-
- @staticmethod
- def backward(ctx, do):
- qkv, o, lse, bias = ctx.saved_tensors
- assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
- # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
- # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
- with torch.inference_mode():
- dqkv = torch.empty_like(qkv)
- _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
- dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
- bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
- return dqkv, None, None, None
-
-
-flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
-
-
-class FlashAttnKVPackedFunc(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
- """
- q: (batch, seqlen_q, nheads, headdim)
- kv: (batch, seqlen_k, 2, nheads, headdim)
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
- """
- # Make sure that the last dimension is contiguous
- q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
- o, lse, ctx.softmax_scale = _flash_attn_forward(
- q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
- )
- ctx.save_for_backward(q, kv, o, lse, bias)
- ctx.causal = causal
- return o
-
- @staticmethod
- def backward(ctx, do):
- q, kv, o, lse, bias = ctx.saved_tensors
- if len(ctx.needs_input_grad) >= 3:
- assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
- # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
- # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
- with torch.inference_mode():
- dq = torch.empty_like(q)
- dkv = torch.empty_like(kv)
- _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
- dq, dkv[:, :, 0], dkv[:, :, 1],
- bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
- return dq, dkv, None, None, None
-
-
-flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
-
-
-class FlashAttnFunc(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
- """
- q: (batch_size, seqlen_q, nheads, headdim)
- k, v: (batch_size, seqlen_k, nheads, headdim)
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
- """
- # Make sure that the last dimension is contiguous
- q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
- o, lse, ctx.softmax_scale = _flash_attn_forward(
- q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
- )
- ctx.save_for_backward(q, k, v, o, lse, bias)
- ctx.causal = causal
- return o
-
- @staticmethod
- def backward(ctx, do):
- q, k, v, o, lse, bias = ctx.saved_tensors
- assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
- # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
- # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
- with torch.inference_mode():
- dq = torch.empty_like(q)
- dk = torch.empty_like(k)
- dv = torch.empty_like(v)
- _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
- bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
- return dq, dk, dv, None, None, None
-
-
-flash_attn_func = FlashAttnFunc.apply
diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py
deleted file mode 100644
index eb4a65cd62..0000000000
--- a/llmfoundry/models/layers/llama_attention_monkeypatch.py
+++ /dev/null
@@ -1,317 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-# This file is copied and modified from
-# https://github.com/huggingface/transformers/blob/fe3c8ab1af558b95f67f5fafc0c55f09fd2b09db/src/transformers/models/llama/modeling_llama.py
-# See the clearly denoted code blocks for the main modifications (there are a few others like type ignores, and error messages)
-
-import logging
-from typing import Callable, Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention
-
-from llmfoundry.models.layers.attention import (
- scaled_multihead_dot_product_attention, triton_flash_attn_fn)
-
-log = logging.getLogger(__name__)
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """Equivalent of torch.repeat_interleave(x, dim=1,
-
- repeats=n_rep).
-
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch, num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(
- q: torch.Tensor,
- k: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- position_ids: Optional[torch.Tensor] = None,
- unsqueeze_dim: int = 1,
-):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable:
- if patch_fn_name == 'torch':
- return llama_attention_patch_torch
- elif patch_fn_name == 'triton':
- return llama_attention_patch_triton
- else:
- raise ValueError(
- f'Unrecognized llama attention patch function: {patch_fn_name}')
-
-
-def llama_attention_patch_torch(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if use_cache:
- raise NotImplementedError(
- 'use_cache is not yet supported when patching Llama attention.')
-
- bsz, q_len, _ = hidden_states.size()
-
- if self.config.pretraining_tp > 1:
- key_value_slicing = (self.num_key_value_heads *
- self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp,
- dim=0)
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [
- F.linear(hidden_states, query_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [
- F.linear(hidden_states, value_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, position_ids)
-
- query_states, key_states = apply_rotary_pos_emb(
- q=query_states,
- k=key_states,
- cos=cos,
- sin=sin,
- position_ids=None,
- )
-
- ### MAIN MODIFICATIONS START HERE ###
- query_states = query_states.transpose(1, 2).view(
- bsz, q_len, self.num_heads * self.head_dim)
- key_states = key_states.transpose(1, 2).view(
- bsz, q_len, self.num_key_value_heads * self.head_dim)
- value_states = value_states.transpose(1, 2).view(
- bsz, q_len, self.num_key_value_heads * self.head_dim)
-
- attn_output, attn_weights, _ = scaled_multihead_dot_product_attention(
- query=query_states,
- key=key_states,
- value=value_states,
- n_heads=self.num_heads,
- kv_n_heads=self.num_key_value_heads,
- past_key_value=None,
- softmax_scale=None,
- attn_bias=attention_mask,
- key_padding_mask=None,
- is_causal=False, # The causal mask is propagated from LLamaForCausalLM
- dropout_p=0,
- training=self.training,
- needs_weights=False,
- )
- ### MAIN MODIFICATIONS END HERE ###
-
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(self.hidden_size //
- self.config.pretraining_tp,
- dim=2)
- o_proj_slices = self.o_proj.weight.split(self.hidden_size //
- self.config.pretraining_tp,
- dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.o_proj(attn_output)
-
- assert isinstance(attn_output, torch.Tensor)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, None
-
-
-def llama_attention_patch_triton(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if use_cache:
- raise NotImplementedError(
- 'use_cache is not yet supported when patching Llama attention.')
- # output_attentions is not support for triton attention
- if output_attentions:
- raise NotImplementedError(
- 'output_attentions is not supported when patching Llama attention with triton attention.'
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- if self.config.pretraining_tp > 1:
- key_value_slicing = (self.num_key_value_heads *
- self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp,
- dim=0)
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [
- F.linear(hidden_states, query_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [
- F.linear(hidden_states, value_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(
- q=query_states,
- k=key_states,
- cos=cos,
- sin=sin,
- position_ids=None,
- )
-
- ### MAIN MODIFICATIONS START HERE ###
- query_states = query_states.transpose(1, 2).view(
- bsz, q_len, self.num_heads * self.head_dim)
- key_states = key_states.transpose(1, 2).view(
- bsz, q_len, self.num_key_value_heads * self.head_dim)
- value_states = value_states.transpose(1, 2).view(
- bsz, q_len, self.num_key_value_heads * self.head_dim)
-
- attn_output, _, _ = triton_flash_attn_fn(
- query=query_states,
- key=key_states,
- value=value_states,
- n_heads=self.num_heads,
- kv_n_heads=self.num_key_value_heads,
- past_key_value=None,
- softmax_scale=None,
- attn_bias=attention_mask,
- key_padding_mask=None,
- is_causal=False, # The causal mask is propagated from LLamaForCausalLM
- dropout_p=0,
- training=self.training,
- needs_weights=False,
- )
- ### MAIN MODIFICATIONS END HERE ###
-
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(self.hidden_size //
- self.config.pretraining_tp,
- dim=2)
- o_proj_slices = self.o_proj.weight.split(self.hidden_size //
- self.config.pretraining_tp,
- dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.o_proj(attn_output)
-
- assert isinstance(attn_output, torch.Tensor)
-
- return attn_output, None, None
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index b5a099002e..20c3850a82 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -20,8 +20,6 @@
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
-
ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
}
@@ -81,16 +79,13 @@ def __init__(
attn_config (Dict): A dictionary used to configure the model's attention module:
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
attn_pdrop (float): The dropout probability for the attention layers.
- attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
+ attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
use the default scale of ``1/sqrt(d_keys)``.
- prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
- extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
- can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
which sub-sequence each token belongs to.
@@ -210,43 +205,20 @@ def _validate_config(self) -> None:
raise ValueError(
"self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
)
- if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
+ if self.attn_config['attn_impl'] not in ['torch', 'flash']:
raise ValueError(
f"Unknown attn_impl={self.attn_config['attn_impl']}")
- if self.attn_config['prefix_lm']:
- warnings.warn(
- VersionedDeprecationWarning(
- 'Support for Prefix Language Models is deprecated.',
- remove_version='0.7.0'))
- if self.attn_config['attn_impl'] == 'triton':
- warnings.warn(
- VersionedDeprecationWarning(
- 'Support for triton attention is deprecated. Please use torch or flash attention.',
- remove_version='0.7.0'))
-
- if self.attn_config['prefix_lm'] and self.attn_config[
- 'attn_impl'] not in ['torch', 'triton']:
- raise NotImplementedError(
- 'prefix_lm only implemented with torch and triton attention.')
-
- if self.attn_config[
- 'attn_impl'] == 'triton' and not self.attn_config['prefix_lm']:
- warnings.warn(
- UserWarning(
- 'If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton".'
- ))
-
if self.attn_config['alibi'] and not check_alibi_support(
self.attn_config['attn_impl']):
raise NotImplementedError(
- 'alibi only implemented with torch, triton, and flash (v2.4.2 or higher) attention.'
+ 'alibi only implemented with torch and flash (v2.4.2 or higher) attention.'
)
if self.attn_config['attn_uses_sequence_id'] and not (
- self.attn_config['attn_impl'] in ['torch', 'triton'] or
+ self.attn_config['attn_impl'] == 'torch' or
(self.attn_config['attn_impl'] == 'flash' and
is_flash_v2_installed(v2_version='v2.1.2'))):
raise NotImplementedError(
- 'attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention.'
+ 'attn_uses_sequence_id only implemented with torch and flash (v2.1.2 or higher) attention.'
)
if self.attn_config['rope'] and (self.attn_config['rope_impl']
not in ['dail', 'hf']):
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index bf93bd71f0..e0a666f62c 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -53,14 +53,6 @@
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
-from llmfoundry.models.utils.adapt_tokenizer import (
- AutoTokenizerForMOD, # type: ignore (see note)
- adapt_tokenizer_for_denoising, # type: ignore (see note)
-)
-from llmfoundry.models.utils.hf_prefixlm_converter import (
- add_bidirectional_mask_if_missing, # type: ignore (see note)
- convert_hf_causal_lm_to_prefix_lm, # type: ignore (see note)
-)
from llmfoundry.models.utils.meta_init_context import \
init_empty_weights # type: ignore (see note)
from llmfoundry.models.utils.param_init_fns import (
@@ -72,12 +64,6 @@
build_act_ckpt_mod_to_blocks,
check_mapping_blocks_overlap)
-try:
- from llmfoundry.models.layers.flash_attn_triton import flash_attn_func as flash_attn_func
-except:
- pass
-# isort: on
-
import logging
log = logging.getLogger(__name__)
@@ -299,7 +285,6 @@ def __init__(self, config: MPTConfig):
super().__init__(config)
self.attn_impl = config.attn_config['attn_impl']
- self.prefix_lm = config.attn_config['prefix_lm']
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
self.alibi = config.attn_config['alibi']
self.alibi_bias_max = config.attn_config['alibi_bias_max']
@@ -364,7 +349,7 @@ def __init__(self, config: MPTConfig):
)
self.apply(self.param_init_fn)
- self.is_causal = not self.prefix_lm
+ self.is_causal = True
# define attn mask
self._attn_bias_initialized = False
@@ -374,7 +359,6 @@ def __init__(self, config: MPTConfig):
config.n_heads,
config.max_seq_len,
self.alibi,
- prefix_lm=self.prefix_lm,
causal=self.is_causal,
use_sequence_id=self.attn_uses_sequence_id,
)
@@ -407,7 +391,6 @@ def _attn_bias(
device: torch.device,
dtype: torch.dtype,
attention_mask: Optional[torch.ByteTensor] = None,
- prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
if not self._attn_bias_initialized:
@@ -426,8 +409,7 @@ def _attn_bias(
)
self._attn_bias_initialized = True
- # flash does not support prefix_lm and will incorporate any
- # attention_mask inside the attention module
+ # flash will incorporate any attention_mask inside the attention module
if self.attn_impl == 'flash':
return self.attn_bias, attention_mask
@@ -438,19 +420,13 @@ def _attn_bias(
attn_bias = self.attn_bias
- # If using torch or triton, we incorporate the prefix_mask (if appropriate)
- if self.prefix_lm:
- assert isinstance(attn_bias, torch.Tensor) # pyright
- assert isinstance(prefix_mask, torch.Tensor) # pyright
- attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
-
- # If using torch or triton, we incorporate sequence_id (if appropriate)
+ # If using torch, we incorporate sequence_id (if appropriate)
if self.attn_uses_sequence_id and sequence_id is not None:
assert isinstance(attn_bias, torch.Tensor) # pyright
attn_bias = apply_sequence_id(attn_bias, sequence_id,
self.config.max_seq_len)
- # If using torch or triton, we incorporate attention_mask. This will output
+ # If using torch, we incorporate attention_mask. This will output
# None in place of attention_mask since it will not be further needed in the
# attention modules.
if attention_mask is not None:
@@ -463,54 +439,17 @@ def _attn_bias(
# clamp to 0 necessary for torch 2.0 compile()
_s_k = max(0, attn_bias.size(-1) - s_k)
attn_bias = attn_bias[:, :, :, _s_k:]
- if prefix_mask is not None and (attention_mask.shape !=
- prefix_mask.shape):
- raise ValueError(
- f'attention_mask shape={attention_mask.shape} ' +
- f'and prefix_mask shape={prefix_mask.shape} are not equal.')
min_val = torch.finfo(attn_bias.dtype).min
attn_bias = attn_bias.masked_fill(
~attention_mask.view(-1, 1, 1, s_k), min_val)
return attn_bias, attention_mask
- def _apply_prefix_mask(self, attn_bias: torch.Tensor,
- prefix_mask: torch.Tensor) -> torch.Tensor:
- s_k, s_q = attn_bias.shape[-2:]
- if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
- raise ValueError(
- 'attn_bias does not match the expected shape. ' +
- f'The last two dimensions should both be {self.config.max_length} '
- + f'but are {s_k} and {s_q}.')
- seq_len = prefix_mask.shape[-1]
- if seq_len > self.config.max_seq_len:
- raise ValueError(
- f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
- )
-
- # select seq_len subset of attn mask
- attn_bias = attn_bias[..., :seq_len, :seq_len]
-
- # Mix the causal max and the bidirectional mask to get the full
- # allowable attention (i.e. full = not accounting for padding yet)
- causal = torch.tril(
- torch.ones((seq_len, seq_len),
- dtype=torch.bool,
- device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
- prefix = prefix_mask.view(-1, 1, 1, seq_len)
- cannot_attend = ~torch.logical_or(causal, prefix.bool())
-
- min_val = torch.finfo(attn_bias.dtype).min
- attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
-
- return attn_bias
-
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
- prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -526,9 +465,6 @@ def forward(
if attention_mask is not None:
attention_mask = attention_mask.bool() # type: ignore
- if prefix_mask is not None:
- prefix_mask = prefix_mask.bool() # type: ignore
-
# These args are passed in by keyword in huggingface's generate function
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
# but have not yet been fully implemented in MPTModel
@@ -538,7 +474,7 @@ def forward(
if output_attentions:
if self.attn_impl != 'torch':
raise NotImplementedError(
- 'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
+ 'output_attentions is not implemented for MPT when using attn_impl `flash`.'
)
if (self.training and attention_mask is not None and
@@ -546,11 +482,6 @@ def forward(
raise NotImplementedError(
'MPT does not support training with left padding.')
- if self.prefix_lm and prefix_mask is None:
- raise ValueError(
- 'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
- )
-
if self.training:
if self.attn_uses_sequence_id and sequence_id is None:
raise ValueError(
@@ -594,8 +525,8 @@ def forward(
+
f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
)
- # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
- # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
+ # For attn_impl: flash, the past key tensor spec is (batch, seq, dim).
+ # For attn_impl: torch, the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
past_position = past_key_values[0][0].size(1)
if self.attn_impl == 'torch':
@@ -654,7 +585,6 @@ def forward(
device=x.device,
dtype=torch.float32,
attention_mask=attention_mask,
- prefix_mask=prefix_mask,
sequence_id=sequence_id,
)
attention_mask_in_length = gen_attention_mask_in_length(
@@ -825,7 +755,6 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
- prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
@@ -843,7 +772,6 @@ def forward(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
- prefix_mask=prefix_mask,
sequence_id=sequence_id,
return_dict=return_dict,
output_attentions=output_attentions,
@@ -975,16 +903,6 @@ def prepare_inputs_for_generation(
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
- if self.transformer.prefix_lm:
- # Leverage a convenience of sequential generation!
- prefix_mask = torch.ones_like(attention_mask)
- # This requires that we're using the cache
- if kwargs.get('use_cache') == False:
- raise NotImplementedError(
- 'MPT with prefix_lm=True does not support use_cache=False.')
- else:
- prefix_mask = None
-
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
@@ -993,7 +911,6 @@ def prepare_inputs_for_generation(
model_inputs.update({
'attention_mask': attention_mask,
- 'prefix_mask': prefix_mask,
'sequence_id': sequence_id,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache', True),
@@ -1089,13 +1006,9 @@ def get_targets(self, batch: Mapping) -> torch.Tensor:
return targets
def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
- if self.model.transformer.prefix_lm:
- add_bidirectional_mask_if_missing(batch)
- # Note: prefix_mask is only used if model.prefix_lm is True
return self.model(
input_ids=batch.get('input_ids', None),
attention_mask=batch.get('attention_mask', None),
- prefix_mask=batch.get('bidirectional_mask', None),
sequence_id=batch.get('sequence_id', None),
inputs_embeds=batch.get('inputs_embeds', None),
)
diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py
index 35a15e530a..7c808ff449 100644
--- a/llmfoundry/models/utils/__init__.py
+++ b/llmfoundry/models/utils/__init__.py
@@ -1,22 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-from llmfoundry.models.utils.adapt_tokenizer import (
- AutoTokenizerForMOD, adapt_tokenizer_for_denoising)
-from llmfoundry.models.utils.hf_prefixlm_converter import (
- add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm)
from llmfoundry.models.utils.meta_init_context import (init_empty_weights,
init_on_device)
from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY,
generic_param_init_fn_)
__all__ = [
- 'AutoTokenizerForMOD',
- 'adapt_tokenizer_for_denoising',
- 'convert_hf_causal_lm_to_prefix_lm',
'init_empty_weights',
'init_on_device',
- 'add_bidirectional_mask_if_missing',
'generic_param_init_fn_',
'MODEL_INIT_REGISTRY',
]
diff --git a/llmfoundry/models/utils/adapt_tokenizer.py b/llmfoundry/models/utils/adapt_tokenizer.py
deleted file mode 100644
index 8cb0c33697..0000000000
--- a/llmfoundry/models/utils/adapt_tokenizer.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Any
-
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
-
-# For consistency with T5 Tokenizer, which is what this adaptation aims to mimic,
-# we hardcode there to be 100 sentinel tokens
-NUM_SENTINEL_TOKENS: int = 100
-
-
-def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase) -> None:
- """Adds sentinel tokens and padding token (if missing).
-
- Expands the tokenizer vocabulary to include sentinel tokens
- used in mixture-of-denoiser tasks as well as a padding token.
-
- All added tokens are added as special tokens. No tokens are
- added if sentinel tokens and padding token already exist.
- """
- # Add sentinel tokens (e.g., , , and so on). Has no effect if these are already in the vocab.
- sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)]
- tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
-
- # If the padding token has not been set, add and use it
- if tokenizer.pad_token is None:
- tokenizer.add_tokens('', special_tokens=True)
- tokenizer.pad_token = ''
- assert tokenizer.pad_token_id is not None
-
- # Register a property that gets us the ids of the sentinel tokens
- sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)])
- _sentinel_token_ids = tokenizer(sentinels,
- add_special_tokens=False).input_ids
-
- tokenizer.sentinel_token_ids = _sentinel_token_ids
-
-
-class AutoTokenizerForMOD(AutoTokenizer):
- """AutoTokenizer + Adaptation for MOD.
-
- A simple wrapper around AutoTokenizer to make instantiating
- an MOD-adapted tokenizer a bit easier.
-
- MOD-adapted tokenizers have sentinel tokens (e.g., ),
- a padding token, and a property to get the token ids of the
- sentinel tokens.
- """
-
- @classmethod
- def from_pretrained(cls, *args: Any,
- **kwargs: Any) -> PreTrainedTokenizerBase:
- """See `AutoTokenizer.from_pretrained` docstring."""
- tokenizer = super().from_pretrained(*args, **kwargs)
- adapt_tokenizer_for_denoising(tokenizer)
- return tokenizer
diff --git a/llmfoundry/models/utils/hf_prefixlm_converter.py b/llmfoundry/models/utils/hf_prefixlm_converter.py
deleted file mode 100644
index 692fab94c2..0000000000
--- a/llmfoundry/models/utils/hf_prefixlm_converter.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Converts Huggingface Causal LM to Prefix LM.
-
-Conversion does lightweight surgery on a HuggingFace
-Causal LM to convert it to a Prefix LM.
-
-Prefix LMs accepts a `bidirectional_mask` input in `forward`
-and treat the input prompt as the prefix in `generate`.
-"""
-
-from types import MethodType
-from typing import Any, List, MutableMapping, Optional, Tuple, Union
-
-import torch
-from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
-from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
-from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
-
-_SUPPORTED_GPT_MODELS = (
- GPT2LMHeadModel,
- GPTJForCausalLM,
- GPTNeoForCausalLM,
- GPTNeoXForCausalLM,
-)
-
-CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM,
- GPTNeoXForCausalLM,]
-
-
-def _convert_gpt_causal_lm_to_prefix_lm(
- model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
- """Converts a GPT-style Causal LM to a Prefix LM.
-
- Supported HuggingFace model classes:
- - `GPT2LMHeadModel`
- - `GPTNeoForCausalLM`
- - `GPTNeoXForCausalLM`
- - `GPTJForCausalLM`
-
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
- """
- if hasattr(model, '_prefix_lm_converted'):
- return model
-
- assert isinstance(model, _SUPPORTED_GPT_MODELS)
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
-
- def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
- """Helper that gets a list of the model's attention modules.
-
- Each module has a `bias` buffer used for causal masking. The Prefix LM
- conversion adds logic to dynamically manipulate these biases to support
- Prefix LM attention masking.
- """
- attn_modules = []
-
- if isinstance(model, GPTNeoXForCausalLM):
- blocks = model.gpt_neox.layers
- else:
- blocks = model.transformer.h
-
- for block in blocks:
- if isinstance(model, GPTNeoForCausalLM):
- # Ignore "local" layers in this model type
- if block.attn.attention_type != 'global':
- continue
- attn_module = block.attn.attention
- elif isinstance(model, GPTNeoXForCausalLM):
- attn_module = block.attention
- else:
- attn_module = block.attn
-
- attn_modules.append(attn_module)
-
- return attn_modules
-
- # Rename methods to allow:
- # - new `forward` to wrap original `forward`
- # - new `generate` to wrap original `generate`
- setattr(model, '_original_forward', getattr(model, 'forward'))
- setattr(model, '_original_generate', getattr(model, 'generate'))
-
- def forward(
- self: CAUSAL_GPT_TYPES,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- bidirectional_mask: Optional[torch.Tensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ):
- """Wraps original forward to enable PrefixLM attention."""
-
- def call_og_forward():
- if isinstance(self, GPTNeoXForCausalLM):
- return self._original_forward(
- input_ids=input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- labels=labels,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- else:
- return self._original_forward(
- input_ids=input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- labels=labels,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- if bidirectional_mask is None:
- # This wrapper is a no-op if bidirectional masks are not supplied
- return call_og_forward()
- assert isinstance(bidirectional_mask, torch.Tensor)
-
- attn_modules = _get_attn_modules(model)
-
- # Handle bidirectional_mask sizing
- # Note: all attn_modules.bias have the same size
- b, s = bidirectional_mask.shape
-
- max_length = attn_modules[0].bias.shape[-1] # type: ignore
-
- if s > max_length:
- raise ValueError(
- f'bidirectional_mask sequence length (={s}) exceeds the ' +\
- f'max length allowed by the model ({max_length}).'
- )
- assert s <= max_length
- if s < max_length:
- pad = torch.zeros((int(b), int(max_length - s)),
- dtype=bidirectional_mask.dtype,
- device=bidirectional_mask.device)
- bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
- bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
-
- # Incorporate the bidirectional mask into the original causal mask
- for attn_module in attn_modules:
- assert isinstance(attn_module.bias, torch.Tensor)
- attn_module.bias.data = torch.logical_or(attn_module.bias.data,
- bidirectional)
-
- # Collect outputs using the model's original forward method
- output = call_og_forward()
-
- # Reset the masks
- for attn_module in attn_modules:
- attn_module.bias.data = torch.tril(
- attn_module.bias.data[0, 0])[None, None] # type: ignore
-
- # Return the outputs
- return output
-
- def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any):
- """Wraps original generate to enable PrefixLM attention."""
- attn_modules = _get_attn_modules(model)
-
- # A convenient answer to PrefixLM generation is to set the causal mask
- # to be bidirectional. All the tokens in the input prompt can attend to
- # one another and, since tokens are generated one-by-one, each new
- # token gets to see everything behind it. This depends on activations
- # being cached and not updated, which is how the HF implementation works.
- for attn_module in attn_modules:
- attn_module.bias.data[:] = 1 # type: ignore
-
- # Collect outputs using the model's original forward method
- output = self._original_generate(*args, **kwargs)
-
- # Reset the masks
- for attn_module in attn_modules:
- attn_module.bias.data = torch.tril(
- attn_module.bias.data[0, 0])[None, None] # type: ignore
-
- # Return the outputs
- return output
-
- # Replace `forward` and `generate` with the new wrappers
- setattr(model, 'forward', MethodType(forward, model))
- setattr(model, 'generate', MethodType(generate, model))
-
- # Finally, tag the model so that this conversion cannot happen again.
- setattr(model, '_prefix_lm_converted', True)
- return model
-
-
-_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS
-
-CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM,
- GPTNeoXForCausalLM]
-
-
-def convert_hf_causal_lm_to_prefix_lm(
- model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
- """Converts a HuggingFace Causal LM to a Prefix LM.
-
- Supported HuggingFace model classes:
- - `GPT2LMHeadModel`
- - `GPTNeoForCausalLM`
- - `GPTNeoXForCausalLM`
- - `GPTJForCausalLM`
-
- Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
- `generate` method and/or select underlying methods depending on the model class.
-
- These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
-
- Notes on training:
- To actually train the converted model as a Prefix LM, training batches will need to indicate
- the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
-
- **This is not a standard input and requires custom layers either within or after your dataloader.**
-
- In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
- such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
- That is, the prefix portion of the sequence should not generate any loss. Loss should only be
- generated by the target portion of the sequence.
-
- Notes on `GPTNeoForCausalLM`:
- To simplify the implementation, "global" and "local" attention layers are handled differently.
- For "global" layers, we handle conversion as described above. For "local" layers, which use a
- causal attention mask within a restricted local window, we do not alter the masking.
-
- Notes on `forward` method conversion:
- After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
- which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
- belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
- 0 indicates token positions belonging to the target.
-
- The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
- causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
- the causal masks before returning the result.
-
- Notes on `generate` method conversion:
- After conversion, the `generate` method will have the same signature but will internally
- convert all causal masks to be purely bidirectional, call the original `generate` method, and
- (where appropriate) reset the causal masks before returning the result.
-
- This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
- "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
- each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
- another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
- previously-generated tokens (also as expected in a Prefix LM).
-
- To preserve the API, the original methods are renamed to `_original_forward` and
- `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
- them, respectively. Although implementation details vary by model class.
- """
- if isinstance(model, _SUPPORTED_GPT_MODELS):
- return _convert_gpt_causal_lm_to_prefix_lm(model)
- else:
- raise TypeError(
- f'Cannot convert model to Prefix LM. ' +\
- f'Model does not belong to set of supported HF models:' +\
- f'\n{_SUPPORTED_HF_MODELS}'
- )
-
-
-def add_bidirectional_mask_if_missing(batch: MutableMapping):
- """Attempts to add bidirectional_mask to batch if missing.
-
- Raises:
- KeyError if bidirectional_mask is missing and can't be inferred
- """
- if 'bidirectional_mask' not in batch:
- if batch.get('mode', None) == 'icl_task':
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
- for i, continuation_indices in enumerate(
- batch['continuation_indices']):
- batch['bidirectional_mask'][i, continuation_indices] = 0
- elif ('labels' in batch) and ('attention_mask' in batch):
- batch['bidirectional_mask'] = torch.logical_and(
- torch.eq(batch['attention_mask'], 1),
- torch.eq(batch['labels'], -100),
- ).type_as(batch['attention_mask'])
- else:
- raise KeyError(
- 'No bidirectional_mask in batch and not sure how to construct one.'
- )
diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
new file mode 100644
index 0000000000..fe24d0eae6
--- /dev/null
+++ b/llmfoundry/utils/exceptions.py
@@ -0,0 +1,206 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Custom exceptions for the LLMFoundry."""
+from collections.abc import Mapping
+from typing import Any, Dict, List
+
+
+# Finetuning dataloader exceptions
+class MissingHuggingFaceURLSplitError(ValueError):
+ """Error thrown when there's no split used in HF dataset config."""
+
+ def __init__(self) -> None:
+ message = 'When using a HuggingFace dataset from a URL, you must set the ' + \
+ '`split` key in the dataset config.'
+ super().__init__(message)
+
+
+class NotEnoughDatasetSamplesError(ValueError):
+ """Error thrown when there is not enough data to train a model."""
+
+ def __init__(self, dataset_name: str, split: str,
+ dataloader_batch_size: int, world_size: int,
+ full_dataset_size: int, minimum_dataset_size: int) -> None:
+ self.dataset_name = dataset_name
+ self.split = split
+ self.dataloader_batch_size = dataloader_batch_size
+ self.world_size = world_size
+ self.full_dataset_size = full_dataset_size
+ self.minimum_dataset_size = minimum_dataset_size
+ message = (
+ f'Your dataset (name={dataset_name}, split={split}) ' +
+ f'has {full_dataset_size} samples, but your minimum batch size ' +
+ f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
+ +
+ f'your per device batch size is {dataloader_batch_size}. Please increase the number '
+ + f'of samples in your dataset to at least {minimum_dataset_size}.')
+ super().__init__(message)
+
+
+## Tasks exceptions
+class UnknownExampleTypeError(KeyError):
+ """Error thrown when an unknown example type is used in a task."""
+
+ def __init__(self, example: Mapping) -> None:
+ self.example = example
+ message = f'Unknown example type {example=}'
+ super().__init__(message)
+
+
+class TooManyKeysInExampleError(ValueError):
+ """Error thrown when a data sample has too many keys."""
+
+ def __init__(self, desired_keys: set[str], keys: set[str]) -> None:
+ self.desired_keys = desired_keys
+ self.keys = keys
+ message = f'Data sample has {len(keys)} keys in `allowed_keys`: {desired_keys} Please specify exactly one. Provided keys: {keys}'
+ super().__init__(message)
+
+
+class NotEnoughChatDataError(ValueError):
+ """Error thrown when there is not enough chat data to train a model."""
+
+ def __init__(self) -> None:
+ message = 'Chat example must have at least two messages'
+ super().__init__(message)
+
+
+class ConsecutiveRepeatedChatRolesError(ValueError):
+ """Error thrown when there are consecutive repeated chat roles."""
+
+ def __init__(self, repeated_role: str) -> None:
+ self.repeated_role = repeated_role
+ message = f'Conversation roles must alternate but found {repeated_role} repeated consecutively.'
+ super().__init__(message)
+
+
+class InvalidLastChatMessageRoleError(ValueError):
+ """Error thrown when the last message role in a chat example is invalid."""
+
+ def __init__(self, last_role: str, expected_roles: set[str]) -> None:
+ self.last_role = last_role
+ self.expected_roles = expected_roles
+ message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}'
+ super().__init__(message)
+
+
+class IncorrectMessageKeyQuantityError(ValueError):
+ """Error thrown when a message has an incorrect number of keys."""
+
+ def __init__(self, keys: List[str]) -> None:
+ self.keys = keys
+ message = f'Expected 2 keys in message, but found {len(keys)}'
+ super().__init__(message)
+
+
+class InvalidRoleError(ValueError):
+ """Error thrown when a role is invalid."""
+
+ def __init__(self, role: str, valid_roles: set[str]) -> None:
+ self.role = role
+ self.valid_roles = valid_roles
+ message = f'Expected role to be one of {valid_roles} but found: {role}'
+ super().__init__(message)
+
+
+class InvalidContentTypeError(TypeError):
+ """Error thrown when the content type is invalid."""
+
+ def __init__(self, content_type: type) -> None:
+ self.content_type = content_type
+ message = f'Expected content to be a string, but found {content_type}'
+ super().__init__(message)
+
+
+class InvalidPromptTypeError(TypeError):
+ """Error thrown when the prompt type is invalid."""
+
+ def __init__(self, prompt_type: type) -> None:
+ self.prompt_type = prompt_type
+ message = f'Expected prompt to be a string, but found {prompt_type}'
+ super().__init__(message)
+
+
+class InvalidResponseTypeError(TypeError):
+ """Error thrown when the response type is invalid."""
+
+ def __init__(self, response_type: type) -> None:
+ self.response_type = response_type
+ message = f'Expected response to be a string, but found {response_type}'
+ super().__init__(message)
+
+
+class InvalidPromptResponseKeysError(ValueError):
+ """Error thrown when missing expected prompt and response keys."""
+
+ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
+ self.example = example
+ message = f'Expected {mapping=} to have keys "prompt" and "response".'
+ super().__init__(message)
+
+
+class InvalidFileExtensionError(FileNotFoundError):
+ """Error thrown when a file extension is not a safe extension."""
+
+ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
+ self.dataset_name = dataset_name
+ self.valid_extensions = valid_extensions
+ message = (
+ f'safe_load is set to True. No data files with safe extensions {valid_extensions} '
+ + f'found for dataset at local path {dataset_name}.')
+ super().__init__(message)
+
+
+class UnableToProcessPromptResponseError(ValueError):
+ """Error thrown when a prompt and response cannot be processed."""
+
+ def __init__(self, input: Dict) -> None:
+ self.input = input
+ message = f'Unable to extract prompt/response from {input}'
+ super().__init__(message)
+
+
+## Convert Delta to JSON exceptions
+class ClusterDoesNotExistError(ValueError):
+ """Error thrown when the cluster does not exist."""
+
+ def __init__(self, cluster_id: str) -> None:
+ self.cluster_id = cluster_id
+ message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!'
+ super().__init__(message)
+
+
+class FailedToCreateSQLConnectionError(RuntimeError):
+ """Error thrown when client can't sql connect to Databricks."""
+
+ def __init__(self) -> None:
+ message = 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
+ super().__init__(message)
+
+
+class FailedToConnectToDatabricksError(RuntimeError):
+ """Error thrown when the client fails to connect to Databricks."""
+
+ def __init__(self) -> None:
+ message = 'Failed to create databricks connection. Check hostname and access token!'
+ super().__init__(message)
+
+
+## Convert Text to MDS exceptions
+class InputFolderMissingDataError(ValueError):
+ """Error thrown when the input folder is missing data."""
+
+ def __init__(self, input_folder: str) -> None:
+ self.input_folder = input_folder
+ message = f'No text files were found at {input_folder}.'
+ super().__init__(message)
+
+
+class OutputFolderNotEmptyError(FileExistsError):
+ """Error thrown when the output folder is not empty."""
+
+ def __init__(self, output_folder: str) -> None:
+ self.output_folder = output_folder
+ message = f'{output_folder} is not empty. Please remove or empty it and retry.'
+ super().__init__(message)
diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py
index 9fdc20c0d6..5a198bc8df 100644
--- a/llmfoundry/utils/huggingface_hub_utils.py
+++ b/llmfoundry/utils/huggingface_hub_utils.py
@@ -131,7 +131,8 @@ def edit_files_for_hf_compatibility(
folder: str,
flatten_imports_prefix: Sequence[str] = ('llmfoundry',),
remove_imports_prefix: Sequence[str] = ('composer', 'omegaconf',
- 'llmfoundry.metrics'),
+ 'llmfoundry.metrics',
+ 'llmfoundry.utils.builders'),
) -> None:
"""Edit files to be compatible with Hugging Face Hub.
@@ -139,7 +140,7 @@ def edit_files_for_hf_compatibility(
folder (str): The folder to process.
flatten_imports_prefix (Sequence[str], optional): Sequence of prefixes to flatten. Defaults to ('llmfoundry',).
remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
- Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics').
+ Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics', 'llmfoundry.utils.builders').
"""
files_to_process = [
os.path.join(folder, filename)
diff --git a/llmfoundry/utils/logging_utils.py b/llmfoundry/utils/logging_utils.py
index 3a2d1eedd8..f6c930beab 100644
--- a/llmfoundry/utils/logging_utils.py
+++ b/llmfoundry/utils/logging_utils.py
@@ -2,6 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
import logging
+import os
+
+from composer.loggers import MosaicMLLogger
+from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR,
+ MOSAICML_PLATFORM_ENV_VAR)
__all__ = [
'SpecificWarningFilter',
@@ -23,3 +28,12 @@ def __init__(self, message_to_suppress: str):
def filter(self, record: logging.LogRecord) -> bool:
return self.message_to_suppress not in record.getMessage()
+
+
+def get_mosaicml_logger():
+ if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower(
+ ) == 'true' and os.environ.get(MOSAICML_ACCESS_TOKEN_ENV_VAR):
+ # Adds mosaicml logger to composer if the run was sent from Mosaic platform, access token is set
+ return MosaicMLLogger()
+ else:
+ return None
diff --git a/pyproject.toml b/pyproject.toml
index 503899c4fa..6bde062abb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,8 +22,7 @@ include = [
# Pyright
[tool.pyright]
-exclude = ['env-**', 'venv*', '**/flash_attn_triton.py', '.venv']
-ignore = ['llmfoundry/models/layers/flash_attn_triton.py']
+exclude = ['env-**', 'venv*', '.venv']
stubPath = "" # suppress useless 'stubPath is not a valid directory' errors
reportUnnecessaryIsInstance = "none" # it is ok to do this for clarity or safety
diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py
index 856f038651..aefafdb49a 100644
--- a/scripts/data_prep/convert_delta_to_json.py
+++ b/scripts/data_prep/convert_delta_to_json.py
@@ -33,6 +33,11 @@
from pyspark.sql.dataframe import DataFrame as SparkDataFrame
from pyspark.sql.types import Row
+from llmfoundry.utils import maybe_create_mosaicml_logger
+from llmfoundry.utils.exceptions import (ClusterDoesNotExistError,
+ FailedToConnectToDatabricksError,
+ FailedToCreateSQLConnectionError)
+
MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'
@@ -401,9 +406,8 @@ def validate_and_get_cluster_info(cluster_id: str,
w = WorkspaceClient()
res = w.clusters.get(cluster_id=cluster_id)
if res is None:
- raise ValueError(
- f'Cluster id {cluster_id} does not exist. Check cluster id and try again!'
- )
+ raise ClusterDoesNotExistError(cluster_id)
+
stripped_runtime = re.sub(
r'[a-zA-Z]',
'',
@@ -436,9 +440,7 @@ def validate_and_get_cluster_info(cluster_id: str,
cluster_id=cluster_id).getOrCreate()
except Exception as e:
- raise RuntimeError(
- 'Failed to create databricks connection. Check hostname and access token!'
- ) from e
+ raise FailedToConnectToDatabricksError() from e
else:
try:
dbsql = sql.connect(
@@ -449,9 +451,7 @@ def validate_and_get_cluster_info(cluster_id: str,
access_token=databricks_token,
)
except Exception as e:
- raise RuntimeError(
- 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
- ) from e
+ raise FailedToCreateSQLConnectionError() from e
return method, dbsql, sparkSession
@@ -462,13 +462,13 @@ def fetch_DT(args: Namespace) -> None:
obj = urllib.parse.urlparse(args.json_output_folder)
if obj.scheme != '':
raise ValueError(
- f'Check the json_output_folder and verify it is a local path!')
+ 'Check the json_output_folder and verify it is a local path!')
if os.path.exists(args.json_output_folder):
if not os.path.isdir(args.json_output_folder) or os.listdir(
args.json_output_folder):
raise RuntimeError(
- f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
+ f'Output folder {args.json_output_folder} already exists and is not empty. Please remove it and retry.'
)
os.makedirs(args.json_output_folder, exist_ok=True)
@@ -547,12 +547,18 @@ def fetch_DT(args: Namespace) -> None:
'The name of the combined final jsonl that combines all partitioned jsonl'
)
args = parser.parse_args()
+ mosaicml_logger = maybe_create_mosaicml_logger()
- from databricks.sdk import WorkspaceClient
- w = WorkspaceClient()
- args.DATABRICKS_HOST = w.config.host
- args.DATABRICKS_TOKEN = w.config.token
+ try:
+ w = WorkspaceClient()
+ args.DATABRICKS_HOST = w.config.host
+ args.DATABRICKS_TOKEN = w.config.token
- tik = time.time()
- fetch_DT(args)
- log.info('Elapsed time', time.time() - tik)
+ tik = time.time()
+ fetch_DT(args)
+ log.info('Elapsed time', time.time() - tik)
+
+ except Exception as e:
+ if mosaicml_logger is not None:
+ mosaicml_logger.log_exception(e)
+ raise e
diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py
index bfd60b8ee1..df39e38a90 100644
--- a/scripts/data_prep/convert_text_to_mds.py
+++ b/scripts/data_prep/convert_text_to_mds.py
@@ -18,10 +18,14 @@
from transformers import AutoTokenizer
from llmfoundry.data import ConcatTokensDataset
+from llmfoundry.utils import maybe_create_mosaicml_logger
from llmfoundry.utils.data_prep_utils import (DownloadingIterable,
merge_shard_groups)
+from llmfoundry.utils.exceptions import (InputFolderMissingDataError,
+ OutputFolderNotEmptyError)
log = logging.getLogger(__name__)
+
DONE_FILENAME = '.text_to_mds_conversion_done'
@@ -369,7 +373,7 @@ def convert_text_to_mds(
object_names = get_object_names(input_folder)
if len(object_names) == 0:
- raise ValueError(f'No text files were found at {input_folder}.')
+ raise InputFolderMissingDataError(input_folder)
# Check if the text files in the bucket have already been processed.
if not reprocess and is_already_processed(output_folder, args_str,
@@ -386,8 +390,7 @@ def convert_text_to_mds(
).name if is_remote_output else output_folder
if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0:
- raise FileExistsError(
- f'{output_folder=} is not empty. Please remove or empty it.')
+ raise OutputFolderNotEmptyError(output_folder)
if processes > 1:
# Download and convert the text files in parallel
@@ -446,14 +449,21 @@ def _args_str(original_args: Namespace) -> str:
if __name__ == '__main__':
args = parse_args()
- convert_text_to_mds(tokenizer_name=args.tokenizer,
- output_folder=args.output_folder,
- input_folder=args.input_folder,
- concat_tokens=args.concat_tokens,
- eos_text=args.eos_text,
- bos_text=args.bos_text,
- no_wrap=args.no_wrap,
- compression=args.compression,
- processes=args.processes,
- reprocess=args.reprocess,
- args_str=_args_str(args))
+ mosaicml_logger = maybe_create_mosaicml_logger()
+
+ try:
+ convert_text_to_mds(tokenizer_name=args.tokenizer,
+ output_folder=args.output_folder,
+ input_folder=args.input_folder,
+ concat_tokens=args.concat_tokens,
+ eos_text=args.eos_text,
+ bos_text=args.bos_text,
+ no_wrap=args.no_wrap,
+ compression=args.compression,
+ processes=args.processes,
+ reprocess=args.reprocess,
+ args_str=_args_str(args))
+ except Exception as e:
+ if mosaicml_logger is not None:
+ mosaicml_logger.log_exception(e)
+ raise e
diff --git a/scripts/misc/convert_examples_ckpt.py b/scripts/misc/convert_examples_ckpt.py
index a533aec72d..db1301674c 100644
--- a/scripts/misc/convert_examples_ckpt.py
+++ b/scripts/misc/convert_examples_ckpt.py
@@ -125,7 +125,7 @@ def convert_examples_ckpt(
'attn_clip_qkv', attn_config_defaults['clip_qkv'])
for k in [
- 'attn_pdrop', 'attn_impl', 'softmax_scale', 'prefix_lm',
+ 'attn_pdrop', 'attn_impl', 'softmax_scale',
'attn_uses_sequence_id', 'alibi', 'alibi_bias_max'
]:
if k in hf_config:
diff --git a/scripts/train/README.md b/scripts/train/README.md
index 0d9a335848..36974ec943 100644
--- a/scripts/train/README.md
+++ b/scripts/train/README.md
@@ -335,7 +335,7 @@ train_loader:
# Using Flash Attention
-Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. The support for Flash Attention V1 has been deprecated, and we recommend using Flash Attention V2. We also recommend using `Flash` attention instead of `Triton` attention, unless you're training Prefix Language Models (in which case we recommend using `Triton`). To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. Next, how you specify to use Flash Attention depends on which model you are using.
+Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). LLM Foundry supports Flash Attention V2. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. Next, how you specify to use Flash Attention depends on which model you are using.
For MPT, you can specify Flash Attention in your YAML like so:
```yaml
@@ -343,8 +343,6 @@ model:
name: mpt_causal_lm
...
attn_config:
- # Will use either V1 or V2 depending on what is installed
- # "triton" will use the Triton implementation
attn_impl: flash
...
```
@@ -356,8 +354,6 @@ model:
pretrained_model_name_or_path: mosaicml/mpt-7b
...
config_overrides:
- # Will use either V1 or V2 depending on what is installed
- # "triton" will use the Triton implementation
attn_config:
attn_impl: flash
...
diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py
index a020745581..aff570e3d4 100644
--- a/scripts/train/benchmarking/submit_benchmarks.py
+++ b/scripts/train/benchmarking/submit_benchmarks.py
@@ -141,7 +141,7 @@ def parse_args():
nargs='+',
help='model sizes to test')
- parser.add_argument('--attn_impl', type=str, default='triton')
+ parser.add_argument('--attn_impl', type=str, default='flash')
parser.add_argument('-c',
'--clusters',
diff --git a/scripts/train/benchmarking/sweep.py b/scripts/train/benchmarking/sweep.py
index 441c9825f8..57b3aa262c 100644
--- a/scripts/train/benchmarking/sweep.py
+++ b/scripts/train/benchmarking/sweep.py
@@ -64,14 +64,14 @@
'--fsdp_config_activation_checkpointing true',
'--fsdp_config_shard_strategy FULL_SHARD',
'--microbatch_size 16',
- '--attn_impl triton',
+ '--attn_impl flash',
],
[
'--model_yamls 30b.yaml',
'--fsdp_config_activation_checkpointing true',
'--fsdp_config_shard_strategy FULL_SHARD',
'--microbatch_size 8',
- '--attn_impl triton',
+ '--attn_impl flash',
],
[
'--model_yamls 70b.yaml',
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 93491452dd..44cfc053f4 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -24,7 +24,6 @@
maybe_create_mosaicml_logger)
install()
-
from llmfoundry.callbacks import AsyncEval
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
@@ -56,43 +55,10 @@ def validate_config(cfg: DictConfig):
loaders.append(eval_loader)
for loader in loaders:
if loader.name == 'text':
- if cfg.model.name in ['hf_prefix_lm', 'hf_t5']:
+ if cfg.model.name == 'hf_t5':
raise ValueError(
f'Model type "{cfg.model.name}" is not supported when using the "text " ' +\
- f'dataloader. Please use the "text_denoising" dataloader to pre-train that model type.')
- elif loader.name == 'text_denoising':
- if cfg.model.name == 'hf_causal_lm':
- raise ValueError(
- f'Model type "{cfg.model.name}" is not supported when using the "text_denoising" ' +\
- f'dataloader. Please use the "text" dataloader to pre-train that model type.')
- if loader.mixture_of_denoisers.decoder_only_format and cfg.model.name == 'hf_t5':
- warnings.warn(
- 'Model type "hf_t5" requires `decoder_only_format` to be ``False``. ' +\
- 'Overriding `decoder_only_format` from ``True`` to ``False``.')
- loader.mixture_of_denoisers.decoder_only_format = False
- if (not loader.mixture_of_denoisers.decoder_only_format
- ) and cfg.model.name == 'hf_prefix_lm':
- warnings.warn(
- 'Model type "hf_prefix_lm" requires `decoder_only_format` to be ``True``. ' +\
- 'Overriding `decoder_only_format` from ``False`` to ``True``.')
- loader.mixture_of_denoisers.decoder_only_format = True
- elif loader.name == 'finetuning':
- if cfg.model.name == 'hf_prefix_lm':
- is_prefix_lm = True
- elif cfg.model.name == 'mpt_causal_lm':
- is_prefix_lm = cfg.model.get('attn_config',
- {}).get('prefix_lm', False)
- else:
- # Note: This only covers the two prefix-lms introduced in this repo
- is_prefix_lm = False
- target_responses = loader.dataset.get('target_responses', 'last')
- target_prompts = loader.dataset.get('target_prompts', 'none')
- prefix_lm_safe = target_responses == 'last' and target_prompts == 'none'
- if is_prefix_lm and not prefix_lm_safe:
- raise ValueError(
- 'The model configuration is building a Prefix-LM, which requires that the finetuning ' +\
- 'dataloader uses `target_responses`="last" and `target_prompts`="none".'
- )
+ f'dataloader. Only finetuning is supported.')
if 'icl_tasks' in cfg:
if cfg.model.name == 'hf_t5':
@@ -493,11 +459,16 @@ def main(cfg: DictConfig) -> Trainer:
# Dataloaders
log.info('Building train loader...')
- train_loader = build_dataloader(
- train_loader_config,
- tokenizer,
- device_train_batch_size,
- )
+ try:
+ train_loader = build_dataloader(
+ train_loader_config,
+ tokenizer,
+ device_train_batch_size,
+ )
+ except Exception as e:
+ if mosaicml_logger is not None:
+ mosaicml_logger.log_exception(e)
+ raise e
if mosaicml_logger is not None:
mosaicml_logger.log_metrics({'data_validated': time.time()})
@@ -531,7 +502,6 @@ def main(cfg: DictConfig) -> Trainer:
eval_loader_config, callback_configs,
tokenizer_name, load_path, icl_tasks_config,
eval_gauntlet_config)
-
# Build Model
log.info('Initializing model...')
model = build_composer_model(
@@ -556,13 +526,19 @@ def main(cfg: DictConfig) -> Trainer:
optimizer = build_optimizer(model, optimizer_name, optimizer_config)
# Now add the eval metrics
- if eval_loader_config is not None and not use_async_eval:
- eval_metrics = model.get_metrics(is_train=False)
- non_icl_metrics = [
- metric_name for metric_name, metric in eval_metrics.items()
- if not isinstance(metric, InContextLearningMetric)
- ]
- evaluators = add_metrics_to_eval_loaders(evaluators, non_icl_metrics)
+ try:
+ if eval_loader_config is not None and not use_async_eval:
+ eval_metrics = model.get_metrics(is_train=False)
+ non_icl_metrics = [
+ metric_name for metric_name, metric in eval_metrics.items()
+ if not isinstance(metric, InContextLearningMetric)
+ ]
+ evaluators = add_metrics_to_eval_loaders(evaluators,
+ non_icl_metrics)
+ except Exception as e:
+ if mosaicml_logger is not None:
+ mosaicml_logger.log_exception(e)
+ raise e
# Build the Trainer
log.info('Building trainer...')
diff --git a/scripts/train/yamls/finetune/dbrx-full-ft.yaml b/scripts/train/yamls/finetune/dbrx-full-ft.yaml
new file mode 100644
index 0000000000..859f59138c
--- /dev/null
+++ b/scripts/train/yamls/finetune/dbrx-full-ft.yaml
@@ -0,0 +1,132 @@
+# Note: This requires ~64x80GB GPUs
+max_seq_len: 4096
+icl_seq_len: 1024
+
+# Run Name
+run_name: # If left blank, will be read from env var $RUN_NAME
+
+# Model
+model:
+ name: hf_causal_lm
+ pretrained: true
+ init_device: mixed
+ use_auth_token: true
+ config_overrides: {}
+ use_flash_attention_2: true
+ pretrained_model_name_or_path: databricks/dbrx-instruct
+
+# Tokenizer
+tokenizer:
+ name: databricks/dbrx-instruct
+ kwargs:
+ model_max_length: ${max_seq_len}
+ trust_remote_code: true
+
+# Dataloaders
+train_loader:
+ name: finetuning
+ dataset:
+ split: train
+ hf_name: mosaicml/dolly_hhrlhf
+ shuffle: true
+ max_seq_len: ${max_seq_len}
+ eos_token_id: 0
+ packing_ratio: auto
+ allow_pad_trimming: false
+ decoder_only_format: true
+ drop_last: true
+ pin_memory: true
+ num_workers: 8
+ prefetch_factor: 2
+ persistent_workers: true
+
+eval_loader:
+ name: finetuning
+ dataset:
+ split: test
+ hf_name: mosaicml/dolly_hhrlhf
+ shuffle: false
+ max_seq_len: ${max_seq_len}
+ packing_ratio: null
+ allow_pad_trimming: false
+ decoder_only_format: true
+ drop_last: true
+ pin_memory: true
+ num_workers: 8
+ prefetch_factor: 2
+ persistent_workers: true
+
+# Optimization
+optimizer:
+ lr: 0.000001
+ name: decoupled_lionw
+ betas:
+ - 0.9
+ - 0.95
+ weight_decay: 1.0e-06
+
+scheduler:
+ name: cosine_with_warmup
+ alpha_f: 0
+ t_warmup: 0.02dur
+
+algorithms:
+ gradient_clipping:
+ clipping_type: norm
+ clipping_threshold: 1
+
+max_duration: 2ep
+eval_interval: 1ep
+global_train_batch_size: 64
+eval_first: false
+# eval_subset_num_batches: -1
+
+# System
+seed: 17
+device_train_microbatch_size: 1
+device_eval_batch_size: 1
+precision: amp_bf16
+autoresume: true
+dist_timeout: 3600
+
+# FSDP
+fsdp_config:
+ mixed_precision: PURE
+ state_dict_type: sharded
+ limit_all_gathers: true
+ sharding_strategy: FULL_SHARD
+ activation_cpu_offload: false
+ activation_checkpointing: true
+ activation_checkpointing_reentrant: false
+
+# Logging
+progress_bar: false
+log_to_console: true
+console_log_interval: 1ba
+
+# Callbacks
+callbacks:
+ lr_monitor: {}
+ speed_monitor:
+ window_size: 1
+ memory_monitor: {}
+ hf_checkpointer:
+ overwrite: true
+ precision: bfloat16
+ save_folder: ./{run_name}/checkpoints
+ save_interval: 1dur
+ runtime_estimator: {}
+# Checkpoint to local filesystem or remote object store
+# save_interval: 5000ba
+# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
+# save_folder: ./{run_name}/checkpoints
+# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints
+
+# Logging
+# loggers:
+# wandb:
+# name:
+# group:
+# mlflow:
+# tracking_uri:
+# experiment_name:
diff --git a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml
new file mode 100644
index 0000000000..3fb4bb99e7
--- /dev/null
+++ b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml
@@ -0,0 +1,140 @@
+# Note: This requires ~16x80GB GPUs
+max_seq_len: 4096
+icl_seq_len: 1024
+
+# Run Name
+run_name: # If left blank, will be read from env var $RUN_NAME
+
+# Model
+model:
+ name: hf_causal_lm
+ pretrained: true
+ init_device: mixed
+ peft_config:
+ r: 64
+ peft_type: LORA
+ task_type: CAUSAL_LM
+ lora_alpha: 128
+ lora_dropout: 0.05
+ target_modules:
+ - Wqkv
+ use_auth_token: true
+ config_overrides: {}
+ use_flash_attention_2: true
+ pretrained_model_name_or_path: databricks/dbrx-instruct
+
+# Tokenizer
+tokenizer:
+ name: databricks/dbrx-instruct
+ kwargs:
+ model_max_length: ${max_seq_len}
+ trust_remote_code: true
+
+# Dataloaders
+train_loader:
+ name: finetuning
+ dataset:
+ split: train
+ hf_name: mosaicml/dolly_hhrlhf
+ shuffle: true
+ max_seq_len: ${max_seq_len}
+ eos_token_id: 0
+ packing_ratio: auto
+ allow_pad_trimming: false
+ decoder_only_format: true
+ drop_last: true
+ pin_memory: true
+ num_workers: 8
+ prefetch_factor: 2
+ persistent_workers: true
+
+eval_loader:
+ name: finetuning
+ dataset:
+ split: test
+ hf_name: mosaicml/dolly_hhrlhf
+ shuffle: false
+ max_seq_len: ${max_seq_len}
+ packing_ratio: null
+ allow_pad_trimming: false
+ decoder_only_format: true
+ drop_last: true
+ pin_memory: true
+ num_workers: 8
+ prefetch_factor: 2
+ persistent_workers: true
+
+# Optimization
+optimizer:
+ lr: 0.0001
+ name: decoupled_lionw
+ betas:
+ - 0.9
+ - 0.95
+ weight_decay: 1.0e-06
+
+scheduler:
+ name: cosine_with_warmup
+ alpha_f: 0
+ t_warmup: 0.02dur
+
+algorithms:
+ gradient_clipping:
+ clipping_type: norm
+ clipping_threshold: 1
+
+max_duration: 2ep
+eval_interval: 1ep
+global_train_batch_size: 16
+eval_first: false
+# eval_subset_num_batches: -1
+
+# System
+seed: 17
+device_train_microbatch_size: 1
+device_eval_batch_size: 1
+precision: amp_bf16
+autoresume: true
+dist_timeout: 3600
+
+# FSDP
+fsdp_config:
+ mixed_precision: PURE
+ state_dict_type: sharded
+ limit_all_gathers: true
+ sharding_strategy: FULL_SHARD
+ activation_cpu_offload: false
+ activation_checkpointing: true
+ activation_checkpointing_reentrant: false
+
+# Logging
+progress_bar: false
+log_to_console: true
+console_log_interval: 1ba
+
+# Callbacks
+callbacks:
+ lr_monitor: {}
+ speed_monitor:
+ window_size: 1
+ memory_monitor: {}
+ hf_checkpointer:
+ overwrite: true
+ precision: bfloat16
+ save_folder: ./{run_name}/checkpoints
+ save_interval: 1dur
+ runtime_estimator: {}
+# Checkpoint to local filesystem or remote object store
+# save_interval: 5000ba
+# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
+# save_folder: ./{run_name}/checkpoints
+# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints
+
+# Logging
+# loggers:
+# wandb:
+# name:
+# group:
+# mlflow:
+# tracking_uri:
+# experiment_name:
diff --git a/setup.py b/setup.py
index 9a26307e5a..22b7cb17ca 100644
--- a/setup.py
+++ b/setup.py
@@ -66,9 +66,6 @@
'mosaicml-cli>=0.6.10,<1',
'onnx==1.14.0',
'onnxruntime==1.15.1',
- 'cmake>=3.25.0,<=3.26.3', # required for triton-pre-mlir below
- # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
- 'triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir_sm90#subdirectory=python',
'boto3>=1.21.45,<2',
'huggingface-hub>=0.17.0,<1.0',
'beautifulsoup4>=4.12.2,<5', # required for model download utils
diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
index b366d8635a..7839455563 100644
--- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py
+++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
@@ -14,7 +14,7 @@
run_query)
-class TestConverDeltaToJsonl(unittest.TestCase):
+class TestConvertDeltaToJsonl(unittest.TestCase):
@patch('scripts.data_prep.convert_delta_to_json.sql.connect')
@patch('scripts.data_prep.convert_delta_to_json.os.makedirs')
diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py
index 3a00a8889f..e458cb1dfc 100644
--- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py
+++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py
@@ -14,6 +14,8 @@
from streaming import StreamingDataset
from transformers import AutoTokenizer
+from llmfoundry.utils.exceptions import (InputFolderMissingDataError,
+ OutputFolderNotEmptyError)
from scripts.data_prep.convert_text_to_mds import (DONE_FILENAME,
convert_text_to_mds,
download_and_convert,
@@ -209,7 +211,7 @@ def call_convert_text_to_mds(reprocess: bool):
assert os.path.exists(output_folder / 'shard.00000.mds.zstd')
# Test reprocessing.
- with pytest.raises(FileExistsError):
+ with pytest.raises(OutputFolderNotEmptyError):
call_convert_text_to_mds(reprocess=True)
shutil.rmtree(output_folder)
@@ -217,6 +219,24 @@ def call_convert_text_to_mds(reprocess: bool):
call_convert_text_to_mds(reprocess=True)
+def test_input_folder_not_exist(tmp_path: pathlib.Path):
+ with pytest.raises(InputFolderMissingDataError,
+ match='No text files were found'):
+ convert_text_to_mds(
+ tokenizer_name='mosaicml/mpt-7b',
+ output_folder=str(tmp_path / 'output'),
+ input_folder=str(tmp_path / 'input'),
+ concat_tokens=1,
+ eos_text='',
+ bos_text='',
+ no_wrap=False,
+ compression='zstd',
+ processes=1,
+ args_str='Namespace()',
+ reprocess=False,
+ )
+
+
def test_is_already_processed(tmp_path: pathlib.Path):
tmp_path_str = str(tmp_path)
args_str = 'Namespace(x = 5)'
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index f30c64f0ad..b6be3be8e0 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -863,50 +863,6 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool,
delete_transformers_cache()
-@pytest.mark.gpu
-@pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_convert_and_generate_triton(tie_word_embeddings: str,
- tmp_path: pathlib.Path):
- delete_transformers_cache()
-
- cfg = get_config()
- cfg['model']['init_device'] = 'cpu'
- cfg['tie_word_embeddings'] = tie_word_embeddings
- tokenizer = transformers.AutoTokenizer.from_pretrained(
- 'EleutherAI/gpt-neox-20b')
- model = ComposerMPTCausalLM(cfg['model'], tokenizer)
- trainer = Trainer(model=model)
- trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt'))
-
- args = Namespace(composer_path=os.path.join(tmp_path, 'checkpoint.pt'),
- hf_output_path=os.path.join(tmp_path, 'hf-output-folder'),
- output_precision='fp32',
- local_checkpoint_save_location=None,
- hf_repo_for_upload=None,
- trust_remote_code=False,
- test_uploaded_model=False)
- convert_composer_to_hf(args)
-
- config = transformers.AutoConfig.from_pretrained(os.path.join(
- tmp_path, 'hf-output-folder'),
- trust_remote_code=True)
- config.attn_config['attn_impl'] = 'triton'
- model = transformers.AutoModelForCausalLM.from_pretrained(
- os.path.join(tmp_path, 'hf-output-folder'),
- config=config,
- trust_remote_code=True)
- model.to(device='cuda', dtype=torch.bfloat16)
- tokenizer = transformers.AutoTokenizer.from_pretrained(
- os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True)
-
- output = model.generate(tokenizer(
- 'hello', return_tensors='pt')['input_ids'].to(device='cuda'),
- max_new_tokens=1)
- assert output.shape == (1, 2)
-
- delete_transformers_cache()
-
-
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_convert_and_generate_meta(tie_word_embeddings: str,
tmp_path: pathlib.Path):
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index dc06bfb3c1..f5c2631fa7 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -5,7 +5,6 @@
import pathlib
import random
import shutil
-import tempfile
from argparse import Namespace
from contextlib import nullcontext as does_not_raise
from pathlib import Path
@@ -22,8 +21,7 @@
from omegaconf import OmegaConf as om
from streaming import MDSWriter
-from llmfoundry import (build_finetuning_dataloader,
- build_text_denoising_dataloader)
+from llmfoundry import build_finetuning_dataloader
from llmfoundry.data import build_dataloader
from llmfoundry.data.finetuning.collator import (_HF_IGNORE_INDEX,
validate_target_settings)
@@ -36,9 +34,23 @@
build_text_dataloader,
get_tokens_per_batch_func)
from llmfoundry.utils.builders import build_tokenizer
+# yapf: disable
+from llmfoundry.utils.exceptions import (ConsecutiveRepeatedChatRolesError,
+ IncorrectMessageKeyQuantityError,
+ InvalidContentTypeError,
+ InvalidLastChatMessageRoleError,
+ InvalidPromptTypeError,
+ InvalidResponseTypeError,
+ InvalidRoleError,
+ NotEnoughDatasetSamplesError,
+ TooManyKeysInExampleError,
+ UnknownExampleTypeError)
+# yapf: enable
from scripts.data_prep.convert_dataset_hf import main as main_hf
from scripts.data_prep.convert_finetuning_dataset import get_columns_and_format
-from tests.data_utils import make_tiny_ft_dataset
+from tests.data_utils import (make_tiny_conversation_ft_dataset,
+ make_tiny_ft_dataset)
+from tests.test_utils import generate_exclusive_test_params
def get_config(conf_path: str = 'yamls/mpt/125m.yaml'):
@@ -256,72 +268,6 @@ def test_sequence_id_wrapper(eos_token_id: Optional[int],
raise NotImplementedError()
-@pytest.mark.parametrize('decoder_only_format', [True, False])
-@pytest.mark.parametrize('pretokenize', [True, False])
-@pytest.mark.parametrize('packing_ratio', [None, 5.5])
-def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool,
- packing_ratio: Optional[float]):
- # Use the datasets just built in the last test
- tokenizer_name = 'facebook/opt-125m'
- data_local = get_data_local(tokenizer_name, pretokenize)
- path = get_abs_data_path(data_local)
- max_seq_len = 256 if decoder_only_format else 128
-
- if (decoder_only_format is False) and (packing_ratio is not None):
- pytest.xfail('packing_ratio only supported for decoder-only format.')
-
- with tempfile.TemporaryDirectory() as tmpdir:
- cfg = {
- 'name': 'text_denoising',
- 'dataset': {
- 'local': tmpdir,
- 'remote': path,
- 'split': 'val_xsmall',
- 'shuffle': False,
- 'max_seq_len': max_seq_len,
- 'packing_ratio': packing_ratio,
- 'predownload': 1000,
- 'keep_zip': False,
- 'num_workers': None
- },
- 'mixture_of_denoisers': {
- 'decoder_only_format': decoder_only_format,
- 'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
- 'sequence_mask_ratios': 0.25,
- },
- 'drop_last': False,
- 'num_workers': 4,
- }
- cfg = om.create(cfg)
- device_batch_size = 2
-
- expected_keys = ['input_ids', 'attention_mask', 'labels']
- if decoder_only_format:
- expected_keys += ['bidirectional_mask']
- else:
- expected_keys += ['decoder_attention_mask', 'decoder_input_ids']
-
- if packing_ratio is not None:
- expected_keys += ['sequence_id']
-
- tokenizer = build_tokenizer(
- tokenizer_name=tokenizer_name,
- tokenizer_kwargs={'model_max_length': max_seq_len})
-
- loader = build_text_denoising_dataloader(cfg, tokenizer,
- device_batch_size).dataloader
- batch_ix = 0
- for batch in loader:
- for k in expected_keys:
- assert k in batch
- t = batch[k]
- assert t.shape[0] == device_batch_size
- assert t.shape[1] <= max_seq_len
- batch_ix += 1
- if batch_ix >= 5:
- break
-
-
@pytest.mark.parametrize('use_chat_formatting', [True, False])
@pytest.mark.parametrize('decoder_only_format', [True, False])
@pytest.mark.parametrize('allow_pad_trimming', [True, False])
@@ -373,9 +319,7 @@ def test_finetuning_dataloader(use_chat_formatting: bool,
device_batch_size = 2
expected_keys = ['input_ids', 'attention_mask', 'labels']
- if decoder_only_format:
- expected_keys += ['bidirectional_mask']
- else:
+ if not decoder_only_format:
expected_keys += ['decoder_attention_mask', 'decoder_input_ids']
loader = build_finetuning_dataloader(cfg, tokenizer,
@@ -443,19 +387,18 @@ def test_finetuning_dataloader_safe_load(hf_name: str,
@pytest.mark.parametrize('dataset_size', [4, 8])
@pytest.mark.parametrize('device_batch_size', [2, 4])
@pytest.mark.parametrize('drop_last', [True, False])
-@pytest.mark.parametrize('invalid_dataset', [True, False])
def test_finetuning_dataloader_small_data(dataset_size: int,
device_batch_size: int,
- drop_last: bool,
- invalid_dataset: bool):
+ drop_last: bool):
tokenizer_name = 'gpt2'
max_seq_len = 2048
tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
if dist.get_global_rank() == 0:
- make_tiny_ft_dataset(path=tiny_dataset_path,
- size=dataset_size,
- add_bad_data_error=invalid_dataset)
+ make_tiny_ft_dataset(
+ path=tiny_dataset_path,
+ size=dataset_size,
+ )
cfg = {
'name': 'finetuning',
@@ -483,15 +426,10 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
tokenizer_kwargs={'model_max_length': max_seq_len},
)
- expected_keys = ['input_ids', 'attention_mask', 'labels']
- expected_keys += ['bidirectional_mask']
-
error_context = contextlib.nullcontext()
if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last:
- error_context = pytest.raises(ValueError, match='Your dataset')
- if invalid_dataset:
- error_context = pytest.raises(TypeError,
- match='Unable to tokenize example')
+ error_context = pytest.raises(NotEnoughDatasetSamplesError,
+ match='Your dataset')
with error_context:
_ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
@@ -730,11 +668,22 @@ def test_finetuning_dataloader_is_valid_ift_example(
empty_labels_example)
-@pytest.mark.parametrize('add_bad_data_dropped', [True, False])
-@pytest.mark.parametrize('add_bad_data_error', [True, False])
+invalid_prompt_response_params = [
+ 'add_bad_data_dropped', 'add_invalid_prompt_type',
+ 'add_invalid_response_type', 'add_unknown_example_type',
+ 'add_too_many_example_keys'
+]
+
+
+@pytest.mark.parametrize(
+ ','.join(invalid_prompt_response_params),
+ generate_exclusive_test_params(invalid_prompt_response_params))
def test_malformed_data(
add_bad_data_dropped: bool,
- add_bad_data_error: bool,
+ add_invalid_prompt_type: bool,
+ add_invalid_response_type: bool,
+ add_too_many_example_keys: bool,
+ add_unknown_example_type: bool,
tmp_path: pathlib.Path,
):
tokenizer_name = 'mosaicml/mpt-7b'
@@ -759,7 +708,10 @@ def test_malformed_data(
path=tiny_dataset_path,
size=dataset_size,
add_bad_data_dropped=add_bad_data_dropped,
- add_bad_data_error=add_bad_data_error,
+ add_invalid_prompt_type=add_invalid_prompt_type,
+ add_invalid_response_type=add_invalid_response_type,
+ add_unknown_example_type=add_unknown_example_type,
+ add_too_many_example_keys=add_too_many_example_keys,
add_just_bos_eos_pad=True,
pad_token=tokenizer.pad_token,
start_token=tokenizer.bos_token,
@@ -787,19 +739,25 @@ def test_malformed_data(
cfg = om.create(cfg)
- expected_keys = ['input_ids', 'attention_mask', 'labels']
- expected_keys += ['bidirectional_mask']
-
error_context = contextlib.nullcontext()
- if add_bad_data_error:
- error_context = pytest.raises(TypeError,
- match='Unable to tokenize example')
+ if add_invalid_prompt_type:
+ error_context = pytest.raises(InvalidPromptTypeError,
+ match='Expected prompt to be')
+ if add_invalid_response_type:
+ error_context = pytest.raises(InvalidResponseTypeError,
+ match='Expected response to be')
+ if add_unknown_example_type:
+ error_context = pytest.raises(UnknownExampleTypeError,
+ match='Unknown example type')
+ if add_too_many_example_keys:
+ error_context = pytest.raises(TooManyKeysInExampleError,
+ match='Please specify exactly one.')
with error_context:
dl = build_finetuning_dataloader(cfg, tokenizer,
device_batch_size).dataloader
- if not add_bad_data_error:
+ if not any(invalid_prompt_response_params):
# +5 because we added samples with just bos/eos in each of prompt/response
expected_num_batches = (dataset_size + 5) // device_batch_size
@@ -810,6 +768,96 @@ def test_malformed_data(
assert actual_num_batches == expected_num_batches
+invalid_conversation_params = [
+ 'add_invalid_last_chat_message', 'add_invalid_message_key_quantity',
+ 'add_invalid_content_type', 'add_invalid_role', 'add_not_alternating_roles'
+]
+
+
+@pytest.mark.parametrize(
+ ','.join(invalid_conversation_params),
+ generate_exclusive_test_params(invalid_conversation_params))
+def test_malformed_conversation_data(tmp_path: pathlib.Path,
+ add_invalid_last_chat_message: bool,
+ add_invalid_message_key_quantity: bool,
+ add_invalid_content_type: bool,
+ add_invalid_role: bool,
+ add_not_alternating_roles: bool):
+ tokenizer_name = 'mosaicml/mpt-7b'
+ max_seq_len = 2048
+ dataset_size = 5
+ device_batch_size = 5
+ tiny_dataset_folder_path = tmp_path
+ tiny_dataset_path = str(tiny_dataset_folder_path / 'train.jsonl')
+
+ tokenizer = build_tokenizer(
+ tokenizer_name=tokenizer_name,
+ tokenizer_kwargs={'model_max_length': max_seq_len},
+ )
+ tokenizer.add_special_tokens({
+ 'pad_token': '',
+ 'bos_token': '',
+ 'eos_token': '',
+ })
+
+ if dist.get_global_rank() == 0:
+ make_tiny_conversation_ft_dataset(
+ path=tiny_dataset_path,
+ size=dataset_size,
+ add_invalid_last_chat_message=add_invalid_last_chat_message,
+ add_invalid_message_key_quantity=add_invalid_message_key_quantity,
+ add_invalid_content_type=add_invalid_content_type,
+ add_invalid_role=add_invalid_role,
+ add_not_alternating_roles=add_not_alternating_roles,
+ )
+
+ cfg = {
+ 'name': 'finetuning',
+ 'dataset': {
+ 'hf_name': str(tiny_dataset_folder_path),
+ 'split': 'train',
+ 'max_seq_len': max_seq_len,
+ 'decoder_only_format': True,
+ 'allow_pad_trimming': False,
+ 'packing_ratio': None,
+ 'shuffle': True,
+ },
+ 'drop_last': False,
+ 'num_workers': 0,
+ 'prefetch_factor': None,
+ 'pin_memory': False,
+ 'persistent_workers': False,
+ 'timeout': 0
+ }
+
+ cfg = om.create(cfg)
+
+ expected_keys = ['input_ids', 'attention_mask', 'labels']
+ expected_keys += ['bidirectional_mask']
+
+ error_context = contextlib.nullcontext()
+ if add_invalid_last_chat_message:
+ error_context = pytest.raises(InvalidLastChatMessageRoleError,
+ match='Invalid last message role:')
+ if add_invalid_message_key_quantity:
+ error_context = pytest.raises(IncorrectMessageKeyQuantityError,
+ match='Expected 2 keys in message')
+ if add_invalid_content_type:
+ error_context = pytest.raises(InvalidContentTypeError,
+ match='Expected content to be')
+ if add_invalid_role:
+ error_context = pytest.raises(InvalidRoleError,
+ match='Expected role to be one of')
+
+ if add_not_alternating_roles:
+ error_context = pytest.raises(ConsecutiveRepeatedChatRolesError,
+ match='Conversation roles must alternate')
+
+ with error_context:
+ build_finetuning_dataloader(cfg, tokenizer,
+ device_batch_size).dataloader
+
+
def test_finetune_dataloader_pure_pad_responses():
"""Test that dataloader can handle pure-pad responses."""
@@ -920,8 +968,8 @@ def test_token_counting_func(pad_token_id: int, batch_size: int,
@pytest.mark.parametrize('dataloader_type,tensor_input',
[('finetuning-hf', False),
- ('finetuning-streaming', False), ('denoising', False),
- ('text', True), ('text', False)])
+ ('finetuning-streaming', False), ('text', True),
+ ('text', False)])
@pytest.mark.parametrize('pad_token_id', [100, None])
@pytest.mark.parametrize('batch_size', [1, 8])
@pytest.mark.parametrize('model_max_length', [1024])
@@ -956,10 +1004,6 @@ def test_token_counting_func_dataloader_setting(
torch.tensor(b['input_ids']) for b in batch_tokenized
]
- if dataloader_type == 'denoising':
- expected_token_count += 2 * batch_size # for the two eos tokens
- expected_token_count += 5 * batch_size # for the corruption prefix tokens
-
if dataloader_type in {'finetuning-hf', 'finetuning-streaming'}:
for b in batch_tokenized:
b['labels'] = b['input_ids'].copy() # type: ignore
@@ -1031,30 +1075,6 @@ def test_token_counting_func_dataloader_setting(
monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset',
lambda *args, **kwargs: ds_mock)
dl = build_text_dataloader(cfg, gptt, batch_size)
- elif dataloader_type == 'denoising':
- cfg = DictConfig({
- 'name': 'text_denoising',
- 'dataset': {
- 'local': 'dummy-path',
- 'remote': 'dummy-path',
- 'split': 'val_xsmall',
- 'shuffle': False,
- 'max_seq_len': model_max_length,
- 'packing_ratio': None,
- 'predownload': 1000,
- 'keep_zip': False,
- 'num_workers': None
- },
- 'mixture_of_denoisers': {
- 'decoder_only_format': False,
- 'span_mean_lengths_and_ratios': None,
- 'sequence_mask_ratios': 0.25,
- },
- **common_args
- })
- monkeypatch.setattr('llmfoundry.data.denoising.StreamingTextDataset',
- lambda *args, **kwargs: MagicMock())
- dl = build_text_denoising_dataloader(cfg, gptt, batch_size)
else:
raise NotImplementedError()
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index e3eb1b54c1..a45c4d8f0d 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -159,7 +159,7 @@ def test_tokenize_instruct_example_malformed():
]
for example in malformed_prompt_response_examples:
- with pytest.raises(KeyError):
+ with pytest.raises(Exception):
tokenize_formatted_example(example, MagicMock())
diff --git a/tests/data_utils.py b/tests/data_utils.py
index 83cd1dacdc..3c077b5e71 100644
--- a/tests/data_utils.py
+++ b/tests/data_utils.py
@@ -20,8 +20,11 @@ def make_tiny_ft_dataset(
path: str,
size: int = 4,
add_bad_data_dropped: bool = False,
- add_bad_data_error: bool = False,
+ add_invalid_prompt_type: bool = False,
+ add_invalid_response_type: bool = False,
+ add_unknown_example_type: bool = False,
add_just_bos_eos_pad: bool = False,
+ add_too_many_example_keys: bool = False,
pad_token: Optional[str] = None,
start_token: Optional[str] = None,
end_token: Optional[str] = None,
@@ -40,18 +43,28 @@ def make_tiny_ft_dataset(
# empty response
samples.append({'prompt': 'hello', 'response': ''})
- if add_bad_data_error:
+ if add_invalid_prompt_type:
# prompt just None
samples.append({
'prompt': None,
'response': 'goodbye'
}) # type: ignore (intentional test)
+
+ if add_invalid_response_type:
# response just None
samples.append({
'prompt': 'hello',
'response': None
}) # type: ignore (intentional test)
+ if add_too_many_example_keys:
+ # too many keys
+ samples.append({
+ 'prompt': 'hello',
+ 'response': 'goodbye',
+ 'completion': 'bar'
+ })
+
if add_just_bos_eos_pad:
if pad_token is None or start_token is None or end_token is None:
raise ValueError(
@@ -67,6 +80,123 @@ def make_tiny_ft_dataset(
samples.append({'prompt': 'hello', 'response': end_token})
# prompt just pad
samples.append({'prompt': pad_token, 'response': 'goodbye'})
+ if add_unknown_example_type:
+ # unknown example type
+ samples = [{'foo': 'yee', 'bar': 'haw'}]
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ with open(path, 'w') as _f:
+ for sample in samples:
+ _f.write(json.dumps(sample))
+ _f.write('\n')
+
+
+def make_tiny_conversation_ft_dataset(
+ size: int,
+ path: str,
+ add_invalid_last_chat_message: bool = False,
+ add_invalid_message_key_quantity: bool = False,
+ add_invalid_content_type: bool = False,
+ add_invalid_role: bool = False,
+ add_not_alternating_roles: bool = False,
+):
+ if Path(path).suffix != '.jsonl':
+ raise ValueError(f'Path {path} must be a jsonl file.')
+ good_sample = {
+ 'messages': [{
+ 'role': 'system',
+ 'content': 'A conversation between a user and a helpful assistant.'
+ }, {
+ 'role': 'user',
+ 'content': "Hi there. What's the capital of the moon?"
+ }, {
+ 'role': 'assistant',
+ 'content': "This question doesn't make sense."
+ }]
+ }
+
+ samples = [good_sample] * size
+
+ if add_invalid_last_chat_message:
+ # invalid last chat message
+ samples.append({
+ 'messages': [{
+ 'role':
+ 'system',
+ 'content':
+ 'A conversation between a user and a helpful assistant.'
+ }, {
+ 'role': 'user',
+ 'content': "Hi there. What's the capital of the moon?"
+ }, {
+ 'role': 'system',
+ 'content': "This question doesn't make sense."
+ }]
+ })
+
+ if add_invalid_message_key_quantity:
+ # invalid message key quantity
+ samples.append({
+ 'messages': [{
+ 'role':
+ 'system',
+ 'content':
+ 'A conversation between a user and a helpful assistant.',
+ 'extra_key':
+ 'extra value'
+ }]
+ })
+
+ if add_invalid_role:
+ # invalid role
+ samples.append({
+ 'messages': [{
+ 'role':
+ 'system',
+ 'content':
+ 'A conversation between a user and a helpful assistant.'
+ }, {
+ 'role': 'foo',
+ 'content': "Hi there. What's the capital of the moon?"
+ }, {
+ 'role': 'assistant',
+ 'content': "This question doesn't make sense."
+ }]
+ })
+
+ if add_invalid_content_type:
+ # invalid conversation type
+ samples.append({
+ 'messages': [{
+ 'role':
+ 'system',
+ 'content':
+ 'A conversation between a user and a helpful assistant.'
+ }, {
+ 'role': 'user',
+ 'content': "Hi there. What's the capital of the moon?"
+ }, {
+ 'role': 'assistant',
+ 'content': None
+ }]
+ }) # type: ignore (intentional test)
+
+ if add_not_alternating_roles:
+ # not alternating roles
+ samples.append({
+ 'messages': [{
+ 'role':
+ 'system',
+ 'content':
+ 'A conversation between a user and a helpful assistant.'
+ }, {
+ 'role': 'assistant',
+ 'content': "Hi there. What's the capital of the moon?"
+ }, {
+ 'role': 'assistant',
+ 'content': "This question doesn't make sense."
+ }]
+ })
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as _f:
diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py
index d541f0a30c..d5de596199 100644
--- a/tests/models/hf/test_hf_config.py
+++ b/tests/models/hf/test_hf_config.py
@@ -6,6 +6,7 @@
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Mapping
+from unittest.mock import Mock, patch
import pytest
import torch
@@ -83,7 +84,7 @@ def test_tie_weights(tie_word_embeddings: bool):
},
{
'attn_config': {
- 'attn_impl': 'triton'
+ 'attn_impl': 'flash',
}
},
{
@@ -94,7 +95,7 @@ def test_tie_weights(tie_word_embeddings: bool):
{
'max_seq_len': 1024,
'attn_config': {
- 'attn_impl': 'triton'
+ 'attn_impl': 'flash',
},
'init_config': {
'emb_init_std': 5
@@ -104,11 +105,13 @@ def test_tie_weights(tie_word_embeddings: bool):
marks=pytest.mark.xfail(reason='"msl" is a ValueError',
strict=True)),
pytest.param({'attn_config': {
- 'attn_iml': 'triton'
+ 'attn_iml': 'flash'
}},
marks=pytest.mark.xfail(reason='"attn_impl" mispelled',
strict=True)),
])
+@patch('llmfoundry.models.layers.attention.is_flash_v2_installed',
+ new=Mock(return_value=True))
def test_hf_config_override(
model_cfg_overrides: Dict[str, Any],
conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml',
diff --git a/tests/models/hf/test_hf_mpt_gen.py b/tests/models/hf/test_hf_mpt_gen.py
index 1df553f126..917e970852 100644
--- a/tests/models/hf/test_hf_mpt_gen.py
+++ b/tests/models/hf/test_hf_mpt_gen.py
@@ -13,14 +13,14 @@
@pytest.mark.gpu
@pytest.mark.parametrize('device', ['cpu', 'gpu'])
-@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
def test_init_hfhub_mpt(
device: str,
attn_impl: str,
build_tiny_hf_mpt: Callable[..., ComposerHFCausalLM],
mpt_tokenizer: PreTrainedTokenizerBase,
):
- if device == 'cpu' and attn_impl == 'triton':
+ if device == 'cpu' and attn_impl == 'flash':
pytest.skip(f'{attn_impl=} not implemented for {device=}.')
composer_device = get_device(device)
diff --git a/tests/models/hf/test_hf_t5.py b/tests/models/hf/test_hf_t5.py
index 12ee0935e3..fb8689e665 100644
--- a/tests/models/hf/test_hf_t5.py
+++ b/tests/models/hf/test_hf_t5.py
@@ -18,8 +18,6 @@ def test_experimental_hf_t5():
},
'pretrained': False,
'init_device': 'cpu',
- 'z_loss': 0.0,
- 'adapt_vocab_for_denoising': False
})
tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base')
diff --git a/tests/models/hf/test_hf_v_mpt.py b/tests/models/hf/test_hf_v_mpt.py
index b44c8d14c2..82b64ce80c 100644
--- a/tests/models/hf/test_hf_v_mpt.py
+++ b/tests/models/hf/test_hf_v_mpt.py
@@ -16,14 +16,9 @@
('flash', 0.0, False, 1, False),
('flash', 0.1, False, 1, False),
('torch', 0.0, False, 1, False),
- ('triton', 0.0, False, 1, False),
- ('triton', 0.1, False, 1, False),
('torch', 0.0, False, 0, False),
- ('triton', 0.0, False, 0, False),
- ('triton', 0.1, False, 0, False),
('flash', 0.0, False, None, True),
('torch', 0.0, False, None, True),
- ('triton', 0.0, False, None, True),
])
def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
mask_val: Optional[int], no_attn_mask: bool):
@@ -93,7 +88,7 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
# extract model cfg
model_cfg = cfg.model
- # use triton attn implementation
+ # use given attn implementation
model_cfg.attn_impl = attn_impl
model_cfg.alibi = alibi
# modify cfg for HF GPT2 compatibility
diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py
index 622e9574ea..a3b17c36df 100644
--- a/tests/models/layers/test_flash_attn.py
+++ b/tests/models/layers/test_flash_attn.py
@@ -6,12 +6,9 @@
import pytest
import torch
-from llmfoundry.models.layers.attention import (attn_bias_shape,
- build_attn_bias,
- check_alibi_support,
- flash_attn_fn, gen_slopes,
- is_flash_v2_installed,
- triton_flash_attn_fn)
+from llmfoundry.models.layers.attention import (
+ attn_bias_shape, build_attn_bias, check_alibi_support, flash_attn_fn,
+ gen_slopes, is_flash_v2_installed, scaled_multihead_dot_product_attention)
from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info
@@ -239,7 +236,7 @@ def test_sliding_window(sliding_window_size: int):
torch.ones(seqlen_1, seqlen_1), diagonal=-(sliding_window_size + 1)).to(
dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min
attn_bias_2 = attn_bias_2 + window_mask_2
- output_2, _, _ = triton_flash_attn_fn(
+ output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
@@ -257,13 +254,15 @@ def test_sliding_window(sliding_window_size: int):
output_2.sum().backward()
- assert torch.allclose(output_1, output_2)
- assert torch.norm(query_2.grad - query_1.grad # type: ignore
- ) <= 1e-2 + 1e-2 * torch.norm(query_2.grad)
- assert torch.norm(key_2.grad - key_1.grad # type: ignore
- ) <= 1e-2 + 1e-2 * torch.norm(key_2.grad)
- assert torch.norm(value_2.grad - value_1.grad # type: ignore
- ) <= 1e-2 + 1e-2 * torch.norm(value_2.grad)
+ print(torch.max(output_1 - output_2))
+
+ _assert_approx_equal(output_1, output_2)
+ assert (query_2.grad is not None) and (query_1.grad is not None)
+ _assert_approx_equal(query_1.grad, query_2.grad)
+ assert (key_2.grad is not None) and (key_1.grad is not None)
+ _assert_approx_equal(key_1.grad, key_2.grad)
+ assert (value_2.grad is not None) and (value_1.grad is not None)
+ _assert_approx_equal(value_1.grad, value_2.grad)
@pytest.mark.gpu
@@ -322,17 +321,16 @@ def test_alibi_bias(n_heads: int):
def gen_bias():
causal = True
- bs = attn_bias_shape('triton',
+ bs = attn_bias_shape('torch',
n_heads,
seqlen_1,
True,
- prefix_lm=False,
use_sequence_id=False,
causal=causal)
attn_bias = torch.zeros(*bs, device=device)
attn_bias = build_attn_bias(
- 'triton',
+ 'torch',
attn_bias,
n_heads,
seqlen_1,
@@ -344,7 +342,7 @@ def gen_bias():
attn_bias_2 = gen_bias()
- output_2, _, _ = triton_flash_attn_fn(
+ output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
@@ -362,13 +360,14 @@ def gen_bias():
output_2.sum().backward()
- assert torch.allclose(output_1, output_2)
+ _assert_approx_equal(output_1, output_2)
assert (query_2.grad is not None) and (query_1.grad is not None)
- assert torch.norm(query_2.grad -
- query_1.grad) <= 1e-2 + 1e-2 * torch.norm(query_2.grad)
+ _assert_approx_equal(query_1.grad, query_2.grad)
assert (key_2.grad is not None) and (key_1.grad is not None)
- assert torch.norm(key_2.grad -
- key_1.grad) <= 1e-2 + 1e-2 * torch.norm(key_2.grad)
+ _assert_approx_equal(key_1.grad, key_2.grad)
assert (value_2.grad is not None) and (value_1.grad is not None)
- assert torch.norm(value_2.grad -
- value_1.grad) <= 1e-2 + 1e-2 * torch.norm(value_2.grad)
+ _assert_approx_equal(value_1.grad, value_2.grad)
+
+
+def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
+ assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)
diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_torch.py
similarity index 98%
rename from tests/models/layers/test_flash_triton_torch.py
rename to tests/models/layers/test_flash_torch.py
index 4e1efa3f34..c0e9f4b3b5 100644
--- a/tests/models/layers/test_flash_triton_torch.py
+++ b/tests/models/layers/test_flash_torch.py
@@ -23,9 +23,7 @@ def allclose_helper(t0: torch.Tensor,
@pytest.mark.gpu
@pytest.mark.parametrize('attn_impl_0, attn_impl_1', [
- ('flash', 'triton'),
('flash', 'torch'),
- ('triton', 'torch'),
])
@pytest.mark.parametrize('clip_qkv', [True, False])
@pytest.mark.parametrize('qk_ln, qk_gn', [
@@ -146,7 +144,6 @@ def gen_bias(attn_impl: str):
cfg.n_heads,
s,
alibi,
- prefix_lm=False,
use_sequence_id=attn_uses_sequence_id,
causal=causal)
if bs is not None:
@@ -291,7 +288,7 @@ def gen_bias(attn_impl: str):
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['flash', 'triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
def test_vs_mha(attn_impl: str, device: str = 'cuda'):
"""Compare diff attn_impl to torch.nn.MultiheadAttention."""
from llmfoundry.models.layers import attention
@@ -388,7 +385,7 @@ def gen_tca_mask():
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['flash', 'triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
@pytest.mark.parametrize('n_heads', [16, 8])
@pytest.mark.parametrize('kv_n_heads', [4, 2, 1])
def test_grouped_attention_heads(attn_impl: str,
diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py
index dfd3b17f96..1e8ec2383d 100644
--- a/tests/models/layers/test_huggingface_flash.py
+++ b/tests/models/layers/test_huggingface_flash.py
@@ -3,105 +3,16 @@
import contextlib
import os
-from unittest.mock import patch
import pytest
-import torch
-import transformers
from composer.core.precision import get_precision_context
-from composer.utils import reproducibility
from omegaconf import OmegaConf as om
-from transformers.models.llama.modeling_llama import LlamaAttention
from llmfoundry.models.hf.hf_fsdp import rgetattr
from llmfoundry.models.layers.attention import is_flash_v2_installed
-from llmfoundry.models.layers.llama_attention_monkeypatch import (
- llama_attention_patch_torch, llama_attention_patch_triton)
from llmfoundry.utils.builders import build_composer_model, build_tokenizer
-@pytest.mark.parametrize('patch_fn_name', ['torch', 'triton'])
-@pytest.mark.parametrize('explicit_mask', [True, False])
-@pytest.mark.parametrize(
- 'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
-@pytest.mark.gpu
-def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
- model_name: str):
- if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
- pytest.skip(
- 'The CI cluster does not have access to the Llama models, so skip this test.'
- )
-
- device = 'cuda:0'
- sequence_length = 64
- model_dim = 128 if '7b' in model_name else 256
- batch_size = 2
- if patch_fn_name == 'torch':
- patch_fn = llama_attention_patch_torch
- dtype = torch.float32
- atol = 0.0
- rtol = 0.0
- elif patch_fn_name == 'triton':
- # the huggingface implementation of llama performs the softmax in fp32
- # this can result in fairly large differences for the triton implementation
- # but the torch implementation produces the exact same output so we can confirm
- # the implementation is correct
- patch_fn = llama_attention_patch_triton
- dtype = torch.bfloat16
- atol = 1e-2
- rtol = 1e-2
- else:
- raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
-
- llama_config = transformers.AutoConfig.from_pretrained(
- model_name, use_auth_token=True, hidden_size=model_dim)
-
- reproducibility.seed_all(42)
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
-
- rng = torch.Generator(device=device).manual_seed(42)
- hidden_states = torch.randn(batch_size,
- sequence_length,
- model_dim,
- generator=rng,
- dtype=dtype,
- device=device)
- causal_mask = torch.full((sequence_length, sequence_length),
- torch.finfo(torch.float32).min,
- device=device)
- causal_mask = causal_mask.triu(diagonal=1)
- causal_mask = causal_mask[None,
- None, :, :].expand(batch_size, 1, sequence_length,
- sequence_length)
- position_ids = torch.arange(sequence_length,
- dtype=torch.long,
- device=device)
- position_ids = position_ids[None, :].expand(batch_size, sequence_length)
-
- attn_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=position_ids,
- past_key_value=None,
- use_cache=False,
- )
-
- reproducibility.seed_all(42)
- with patch.object(LlamaAttention, 'forward', new=patch_fn):
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
- new_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=position_ids,
- past_key_value=None,
- use_cache=False,
- )
-
- assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
-
-
@pytest.mark.gpu
@pytest.mark.world_size(2)
@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 79a5e4f98f..7244ddc8c2 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -26,7 +26,7 @@
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
from llmfoundry import ComposerHFCausalLM
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias
from llmfoundry.models.layers.attention import (check_alibi_support,
is_flash_v2_installed)
@@ -108,7 +108,7 @@ def gen_random_batch(batch_size: int,
# default to only input ids
if inputs == None:
inputs = ['input_ids']
- # generate input batch of random data, suitable for a Causal or Prefix LM
+ # generate input batch of random data, suitable for a Causal LM
batch = {}
for inp in inputs:
if inp == 'input_ids':
@@ -128,8 +128,6 @@ def gen_random_batch(batch_size: int,
batch['attention_mask'] = torch.ones(size=(batch_size,
test_cfg.max_seq_len),
dtype=torch.int64).to(test_cfg.device)
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
- batch['bidirectional_mask'][:, (test_cfg.max_seq_len // 2):] = 0
return batch
@@ -257,9 +255,7 @@ def test_attention_mechanism(batch_size: int = 2):
x = x + block.resid_ffn_dropout(n)
-@pytest.mark.parametrize('prefixlm', [False, True])
-def test_full_forward_and_backward_gpt2_small(prefixlm: bool,
- batch_size: int = 2):
+def test_full_forward_and_backward_gpt2_small(batch_size: int = 2):
warnings.filterwarnings(
action='ignore',
message='Torchmetrics v0.9 introduced a new argument class property')
@@ -270,11 +266,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm: bool,
device = 'cpu'
neo_cfg.device = device
neo_cfg.max_seq_len = 256
-
- if prefixlm:
- neo_cfg.model.name = 'hf_prefix_lm'
- else:
- neo_cfg.model.name = 'hf_causal_lm'
+ neo_cfg.model.name = 'hf_causal_lm'
tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(neo_cfg.tokenizer)
tokenizer = build_tokenizer(neo_cfg.tokenizer.name,
@@ -687,7 +679,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
@pytest.mark.gpu
-@pytest.mark.parametrize('attention_impl', ['flash', 'triton', 'torch'])
+@pytest.mark.parametrize('attention_impl', ['flash', 'torch'])
@pytest.mark.parametrize('pos_emb_config', [{
'alibi': True,
'rope': False
@@ -799,7 +791,6 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict):
@pytest.mark.parametrize('attention_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
pytest.param('torch', marks=pytest.mark.gpu)
])
@pytest.mark.parametrize('pos_emb_config', [{
@@ -1002,65 +993,9 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
atol=pad_vs_unpad_atol)
-@pytest.mark.parametrize('attention_impl', ['torch', 'triton'])
-def test_advanced_mask_building(attention_impl: str):
- # Test that the correct attention mask is created when both
- # prefix_mask and sequence_id are used
- hf_config = MPTConfig(
- init_device='cpu',
- d_model=16,
- n_heads=1,
- n_layers=1,
- expansion_ratio=1,
- max_seq_len=256,
- emb_pdrop=0.0,
- resid_pdrop=0.0,
- attn_config={
- 'attn_impl': attention_impl,
- 'prefix_lm': True,
- 'attn_uses_sequence_id': True,
- 'alibi': False,
- },
- )
- mpt = MPTForCausalLM(hf_config)
- mpt.eval()
-
- prefix_mask = torch.ByteTensor([[1, 1, 0, 0, 1, 1, 1, 0]])
- sequence_id = torch.LongTensor([[0, 0, 0, 0, 1, 1, 1, 1]])
-
- attn_bias, _ = mpt.transformer._attn_bias(device=mpt.device,
- dtype=torch.float32,
- attention_mask=None,
- prefix_mask=prefix_mask,
- sequence_id=sequence_id)
-
- assert isinstance(attn_bias, torch.Tensor)
- assert attn_bias.shape == torch.Size([1, 1, 8, 8])
-
- # We'll construct the expected value of attn_bias and then compare.
- can_attend = torch.tensor([
- [1, 1, 0, 0, 0, 0, 0, 0],
- [1, 1, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 0, 0, 0, 0],
- [0, 0, 0, 0, 1, 1, 1, 0],
- [0, 0, 0, 0, 1, 1, 1, 0],
- [0, 0, 0, 0, 1, 1, 1, 0],
- [0, 0, 0, 0, 1, 1, 1, 1],
- ])
- can_attend = can_attend.bool().view(1, 1, 8, 8)
- expected_attn_bias = torch.zeros_like(attn_bias)
- expected_attn_bias = expected_attn_bias.masked_fill(
- torch.logical_not(can_attend),
- torch.finfo(attn_bias.dtype).min)
-
- assert torch.equal(attn_bias, expected_attn_bias)
-
-
@pytest.mark.parametrize('attention_impl,precision', [
('torch', 'fp32'),
pytest.param('flash', 'amp_bf16', marks=pytest.mark.gpu),
- pytest.param('triton', 'amp_bf16', marks=pytest.mark.gpu),
pytest.param('torch', 'amp_bf16', marks=pytest.mark.gpu),
pytest.param('torch', 'fp32', marks=pytest.mark.gpu),
])
@@ -1313,7 +1248,6 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('pos_emb_config', [{
'alibi': False,
@@ -1447,7 +1381,6 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('pos_emb_config', [{
'alibi': False,
@@ -1586,7 +1519,6 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict,
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('pos_emb_config', [{
'alibi': False,
@@ -1685,7 +1617,6 @@ def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict,
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('generation_kwargs', [{
'max_new_tokens': 2,
@@ -1882,7 +1813,6 @@ def test_alibi_vs_hf():
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
pytest.param('torch', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('pos_emb_config', [{
@@ -1915,7 +1845,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
attn_impl: str, pos_emb_config: dict):
if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
- if attn_impl in ['flash', 'triton']:
+ if attn_impl == 'flash':
pytest.skip(f'output_attentions only implemented with torch attention.')
if pos_emb_config['rope'] and pos_emb_config[
'rope_impl'] == 'dail' and not is_flash_v2_installed():
@@ -2039,7 +1969,7 @@ def test_hf_init(tmp_path: pathlib.Path,
prepare_fsdp_module(model, optimizer, fsdp_config, precision, device, False)
- model = HuggingFaceModelWithZLoss(model, tokenizer)
+ model = HuggingFaceModelWithFSDP(model, tokenizer)
batch = gen_random_batch(batch_size, test_cfg)
@@ -2058,7 +1988,7 @@ def test_hf_init(tmp_path: pathlib.Path,
@pytest.mark.gpu
-def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
+def test_head_dim_8_flash_mqa_attn(batch_size: int = 2):
test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
test_cfg.device = torch.cuda.current_device()
@@ -2074,7 +2004,7 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
emb_pdrop=0.1,
resid_pdrop=0.2,
attn_config={
- 'attn_impl': 'triton',
+ 'attn_impl': 'flash',
'attn_type': 'multiquery_attention'
},
)
@@ -2086,7 +2016,7 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
mpt = MPTForCausalLM(hf_config)
- model = HuggingFaceModelWithZLoss(mpt, tokenizer, shift_labels=True)
+ model = HuggingFaceModelWithFSDP(mpt, tokenizer, shift_labels=True)
model = model.to(test_cfg.device)
batch = gen_random_batch(batch_size, test_cfg)
diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py
index 9f022ef487..35f130cd46 100644
--- a/tests/models/test_mpt_gen.py
+++ b/tests/models/test_mpt_gen.py
@@ -28,7 +28,6 @@ def forward(
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
- prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
@@ -38,7 +37,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
):
result = super().forward(input_ids, past_key_values, attention_mask,
- prefix_mask, sequence_id, labels, return_dict,
+ sequence_id, labels, return_dict,
output_attentions, output_hidden_states,
use_cache, inputs_embeds)
# Modify the logits to select the next token.
@@ -53,7 +52,7 @@ def forward(
@pytest.mark.world_size(2)
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
@pytest.mark.parametrize('use_alibi', [True, False])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
@patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM',
@@ -93,7 +92,7 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
@pytest.mark.parametrize('use_alibi', [True, False])
def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
build_tiny_mpt: Callable[...,
@@ -144,7 +143,7 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
@pytest.mark.parametrize('use_alibi', [True, False])
def test_mpt_generate_callback_not_tied(
use_alibi: bool, attn_impl: str,
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000000..89517e64ff
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,27 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List
+
+import pytest
+
+
+def generate_exclusive_test_params(param_names: List[str]):
+ """Generates pytest.param objects with one true parameter for testing.
+
+ Creates pytest.param objects for each parameter name given. For each
+ param object, one parameter is set to True (indicating a test case for
+ malformed data) while the rest are set to False.
+
+ Args:
+ param_names (List[str]): The names of parameters to create test cases for.
+
+ Yields:
+ pytest.param: Each with one parameter set to True, indicating the specific case being tested.
+ """
+ for _, name in enumerate(param_names):
+ params = {param_name: False for param_name in param_names}
+ params[name] = True
+ param_values = list(params.values())
+ param_id = f'{name}=True'
+ yield pytest.param(*param_values, id=param_id)