diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index cf6b85b193..144e3f1ad3 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -32,14 +32,6 @@ jobs:
PYPI_PACKAGE_NAME="llm-foundry-test-$(date +%Y%m%d%H%M%S)"
fi
- # Remove the peft, xentropy-cuda-lib and triton-pre-mlir dependencies as PyPI does not
- # support direct installs. The error message for importing PEFT, FusedCrossEntropy,
- # and flash_attn_triton gives instructions on how to install if a user tries to use it
- # without this dependency.
- sed '/xentropy-cuda-lib@git+https:\/\/github.com\/HazyResearch\/flash-attention.git@.*/d' -i setup.py
- sed '/triton-pre-mlir@git+https:\/\/github.com\/vchiley\/triton.git@.*/d' -i setup.py
- sed '/peft@git+https:\/\/github.com\/huggingface\/peft.git.*/d' -i setup.py
-
python -m pip install --upgrade build twine
python -m build
twine check --strict dist/*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 62bc853fb5..3551482244 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,5 @@
default_language_version:
python: python3
-exclude: llmfoundry/models/layers/flash_attn_triton.py
repos:
- repo: https://github.com/google/yapf
rev: v0.32.0
diff --git a/README.md b/README.md
index aaf4b70c35..e8ea026be8 100644
--- a/README.md
+++ b/README.md
@@ -184,7 +184,6 @@ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.o
**Lastly**, install the ROCm enabled flash attention (instructions [here](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support)).
Notes:
-1. `attn_impl: triton` does not work.
1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.
### Intel Gaudi
diff --git a/TUTORIAL.md b/TUTORIAL.md
index 8fa0e41a92..3be4910c4f 100644
--- a/TUTORIAL.md
+++ b/TUTORIAL.md
@@ -32,10 +32,8 @@ This tutorial will provide a brief intro to the repo’s structure and underlyin
- [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on)
- [What hardware can I run inference on?](#what-hardware-can-i-run-inference-on)
- [What is FSDP?](#what-is-fsdp)
- - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton--for-mpt-and-which-one-should-i-use)
+ - [What are the different attention options `torch` / `flash` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--for-mpt-and-which-one-should-i-use)
- [Limitations](#limitations)
- - [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir)
- - [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus)
- [Support for FlashAttention-2](#support-for-flashattention-2)
- [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support)
- [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora)
@@ -144,7 +142,7 @@ name = 'mosaicml/mpt-7b'
# Download config
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
-# (Optional) Use `flash` (preferred) or `triton` backend for fast attention. Defaults to `torch`.
+# (Optional) Use `flash` (preferred) backend for fast attention. Defaults to `torch`.
# config.attn_config['attn_impl'] = 'flash'
# (Optional) Change the `max_seq_len` allowed for inference
# config.max_seq_len = 4096
@@ -291,7 +289,7 @@ The purpose of this section is probably pretty self-evident. You’ve got questi
- If OOMs persist with `device_train_microbatch_size: 1` and `device_eval_batch_size: 1`, you may need to use activation checkpointing `fsdp_config.activation_checkpointing: true` (if you are not already) and, as a last resort, activation CPU offloading `fsdp_config.activation_cpu_offload: true`.
### What hardware can I train on?
-- In general, this repo should work on any system with NVIDIA GPUs. Checkout the `scripts/train/README.md` for more [details on GPU memory requirements]([https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm)). We recommend using `Flash` attention instead of `Triton` attention, unless you're training Prefix Language Models (in which case use `Triton`). Keep in mind you may run into issues with `Flash` or `Triton` support on some GPU types. In that situation, you can fall back to `attn_impl: torch`, or raise an issue in the [Flash Attention github repo](https://github.com/Dao-AILab/flash-attention).
+- In general, this repo should work on any system with NVIDIA GPUs. Checkout the `scripts/train/README.md` for more [details on GPU memory requirements]([https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm)). We recommend using `Flash` attention. Keep in mind you may run into issues with `Flash` support on some GPU types. In that situation, you can fall back to `attn_impl: torch`, or raise an issue in the [Flash Attention github repo](https://github.com/Dao-AILab/flash-attention).
### What hardware can I run eval on?
- Similar to above…
@@ -302,8 +300,8 @@ The purpose of this section is probably pretty self-evident. You’ve got questi
### What is FSDP?
- [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) is a PyTorch implementation of the [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054). FSDP shards networks parameters and the optimizer state across all GPUs. This enables users to train models with large parameter counts which do not fit into a single GPUs memory.
-### What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?
-- **Short answer:** `torch` is the native pytorch attention implementation, and `flash` and `triton` are different implementations of the much more optimized [Flash Attention](https://arxiv.org/abs/2205.14135) method. `triton` and `flash` will be faster (and use less GPU memory) than `torch`, but they might not work with all hardware and environment setups.
+### What are the different attention options `torch` / `flash` for MPT and which one should I use?
+- **Short answer:** `torch` is the native pytorch attention implementation, and `flash` is an implementation of the much more optimized [Flash Attention](https://arxiv.org/abs/2205.14135) method. `flash` will be faster (and use less GPU memory) than `torch`, but they might not work with all hardware and environment setups.
Our training setups typically use `flash`.
@@ -313,7 +311,6 @@ Furthermore, integrating a recomputation schema decreases the sequence length me
- Setting `attn_config.attn_impl=torch` enables a naive Softmax Attention written using base torch operations.
- Setting `attn_config.attn_impl=flash` enables Flash Attention [implemented by Dao et al in the Dao-AILab repo using CUDA](https://github.com/Dao-AILab/flash-attention). This will have linear memory complexity (enabling larger batch sizes) and will run much faster.
- - Setting `attn_config.attn_impl=triton` enables a Flash Attention [implemented using Triton](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/models/layers/flash_attn_triton.py). We recommend using `flash` attention instead of `triton` attention, unless you're training Prefix Language Models (in which case use `Triton`).
+The majority of our training setups use `flash`. -->
#### Limitations
- For training, `torch` uses a lot of memory and is slow.
-- `flash` and `triton` cannot return attention weights and therefore cannot be used with methods that require it.
+- `flash` cannot return attention weights and therefore cannot be used with methods that require it.
- `flash` cannot accept an attention bias. However, it still allows the use of ALiBi positional bias.
-#### What is `triton-pre-mlir`?
-- Torch2 installs and requires a specific version of [Triton](https://openai.com/research/triton).
- `attn_config.attn_impl=triton` requires a different version of triton.
- As a result, you can either use torch2 or `attn_impl=triton`.
- To enable both, we fork triton and make it pip installable as `triton-pre-mlir`.
- `attn_impl=triton` can then use `triton-pre-mlir` leaving the version of triton required for torch2 intact.
-
-#### Known issue with sm86+ GPUs
-- Under the hood, part of `triton-pre-mlir` compile path uses LLVM11.
- H100 GPUs (sm90 GPUs) are not formally supported until LLVM15 (technically it doesn't support anything sm86+).
- Updating the LLVM version used by `triton-pre-mlir` to LLVM13 seems to be relatively easy.
- Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes.
- What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance.
-
#### Support for FlashAttention-2
- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM Foundry supports FlashAttention-2. Please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention).
@@ -352,8 +334,8 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.
| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes |
|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | |
-| ALiBi | model:
attn_config:
alibi: True
| 64.5 | Requires Flash (v2.4.2 or higher) or Triton or Torch attention. |
-| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_impl: dail
| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. |
+| ALiBi | model:
attn_config:
alibi: True
| 64.5 | Requires Flash (v2.4.2 or higher) or Torch attention. |
+| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_impl: dail
| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch` or `flash`. |
| RoPE (Hugging
Face Implementation) | model:
attn_config:
rope: True
rope_impl: hf
| 62.3 | |
### Can I finetune using PEFT / LoRA?
diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index ea9eba74ce..922f738e9a 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -20,15 +20,13 @@
hf_dynamic_modules_logger.addFilter(new_files_warning_filter)
from llmfoundry import algorithms, callbacks, loggers, optim, registry, utils
-from llmfoundry.data import (ConcatTokensDataset, MixtureOfDenoisersCollator,
- NoConcatDataset, Seq2SeqFinetuningCollator,
- build_finetuning_dataloader,
- build_text_denoising_dataloader)
-from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
- ComposerHFT5)
+from llmfoundry.data import (ConcatTokensDataset, NoConcatDataset,
+ Seq2SeqFinetuningCollator,
+ build_finetuning_dataloader)
+from llmfoundry.models.hf import ComposerHFCausalLM, ComposerHFT5
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
- flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn)
+ flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
@@ -36,9 +34,7 @@
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
__all__ = [
- 'build_text_denoising_dataloader',
'build_finetuning_dataloader',
- 'MixtureOfDenoisersCollator',
'Seq2SeqFinetuningCollator',
'MPTBlock',
'FFN_CLASS_REGISTRY',
@@ -50,11 +46,9 @@
'MPTForCausalLM',
'ComposerMPTCausalLM',
'ComposerHFCausalLM',
- 'ComposerHFPrefixLM',
'ComposerHFT5',
'scaled_multihead_dot_product_attention',
'flash_attn_fn',
- 'triton_flash_attn_fn',
'MultiheadAttention',
'NoConcatDataset',
'ConcatTokensDataset',
@@ -70,4 +64,4 @@
'registry',
]
-__version__ = '0.6.0'
+__version__ = '0.7.0'
diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py
index b64f54cfa3..45d1f8237f 100644
--- a/llmfoundry/data/__init__.py
+++ b/llmfoundry/data/__init__.py
@@ -3,8 +3,6 @@
from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
from llmfoundry.data.dataloader import build_dataloader
-from llmfoundry.data.denoising import (MixtureOfDenoisersCollator,
- build_text_denoising_dataloader)
from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator,
build_finetuning_dataloader)
from llmfoundry.data.text_data import (StreamingTextDataset,
@@ -12,12 +10,9 @@
from llmfoundry.registry import dataloaders
dataloaders.register('text', func=build_text_dataloader)
-dataloaders.register('text_denoising', func=build_text_denoising_dataloader)
dataloaders.register('finetuning', func=build_finetuning_dataloader)
__all__ = [
- 'MixtureOfDenoisersCollator',
- 'build_text_denoising_dataloader',
'Seq2SeqFinetuningCollator',
'build_finetuning_dataloader',
'StreamingTextDataset',
diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
deleted file mode 100644
index 303c9298bb..0000000000
--- a/llmfoundry/data/denoising.py
+++ /dev/null
@@ -1,957 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Streaming dataloader for (mixture of) denoising task(s)."""
-
-import logging
-import random
-import sys
-import warnings
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
-
-import numpy as np
-import torch
-from composer.core.data_spec import DataSpec
-from omegaconf import DictConfig
-from omegaconf import OmegaConf as om
-from torch.utils.data import DataLoader
-from transformers import PreTrainedTokenizerBase
-
-from llmfoundry.data.packing import BinPackCollator
-from llmfoundry.data.text_data import (StreamingTextDataset,
- get_tokens_per_batch_func)
-from llmfoundry.models import utils
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
-
-__all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
-
-log = logging.getLogger(__name__)
-
-# HuggingFace hardcodes the ignore index to -100
-_HF_IGNORE_INDEX = -100
-
-# Required signature of any `prefix_function` (see below)
-PREFIX_FUNCTION = Callable[[float, Optional[float], PreTrainedTokenizerBase],
- Sequence[int]]
-
-
-def ul2_prefix_function(
- mask_ratio: float,
- mean_length: Optional[float],
- tokenizer: PreTrainedTokenizerBase,
-) -> Sequence[int]:
- """Generates prefixes based on UL2 paper.
-
- See: http://arxiv.org/abs/2205.05131
- """
- if mean_length is None:
- # This is the case for "sequence to sequence"
- prefix = '[S2S]' if mask_ratio < 1.0 else '[CLM]'
- elif mean_length >= 12 or mask_ratio >= 0.3:
- # UL2 tags this corruption rate "extreme"
- prefix = '[NLG]'
- else:
- # UL2 tags this corruption rate as "regular"
- prefix = '[NLU]'
- return tokenizer(prefix, add_special_tokens=False).input_ids
-
-
-class MixtureOfDenoisersCollator:
- """Data collator for mixture of span-corruption denoisers, as in UL2.
-
- This collator supports a variety of tasks used to pre-train an
- encoder-decoder model or a (prefix LM) decoder-only model. This is meant
- to be used with a dataset that yields tokenized text sequences. It is not
- required that the token sequences are already padded or truncate, as this
- collator will internally truncate and pad as needed.
-
- For the denoising mixture recommended in the original UL2 paper,
- http://arxiv.org/abs/2205.05131, use:
- .. python:
- MixtureOfDenoisersCollator(
- ...,
- span_mean_lengths_and_ratios=[
- [3, .15],
- [8, .15],
- [3, .50],
- [8, .50],
- [64, .15],
- [64, .50],
- ],
- sequence_mask_ratios=0.25
- )
-
- Args:
- tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
- prepare the data from raw text. Any missing sentinel tokens will
- be added by the collator.
- max_seq_length (int): The maximum length of sequences produced by this
- collator. Incoming sequences may be truncated to accommodate this
- limit.
- Note that when formatting for decoder-only models, the context
- tokens and target tokens are concatenated, and max_seq_length
- applies to their combined length. For encoder-decoder models, both
- the encoder and decoder will see up to max_seq_length tokens.
- decoder_only_format (bool, optional): Whether to format the batches
- for a decoder-only model (i.e. a prefix LM) or, if ``False``, an
- encoder-decoder model. Default: ``False``.
- span_mean_lengths_and_rations (optional): A length-2 list of a
- ``[mean_length, mask_ratio]`` pair, or a list of such pairs. Each
- pair adds a span corruption denoising task to the task mixture. For
- example, ``[3, 0.15]`` adds the original span corruption task used
- for pre-training a T5 model as in http://arxiv.org/abs/1910.10683,
- which trained with a single span corruption task that used a mean
- span length of 3 and a mask ratio of 15%.
- Default: ``None`` does not add any span corruption tasks.
- sequence_mask_ratios (optional): A float or list of floats, one for each
- sequence corruption denoising task to add to the task mixture. Each
- sequence mask ratio must be greater than 0.0 and at most 1.0.
- This type of task is a special instance of span corruption, with
- exactly one masked span take from the end of the sequence. The
- length of the span is sampled uniformly such that the average
- portion of masked tokens equals sequence_mask_ratio.
- Note: A value of 1.0 essentially yields causal LM examples.
- Default: ``None` does not add any sequence corruption tasks.
- allow_pad_trimming (bool, optional): Whether to allow the collator to
- trim away sequence regions that are entirely padding (i.e. padding
- for each example in the batch). If ``True``, shorter sequences may
- improve throughput but at a potentially higher memory cost
- owing to variable sequence lengths from batch to batch.
- Default: ``False`` yields batches that are always padded to
- max_seq_length.
- prefix_function (callable, optional): A function that maps denoising
- task parameters (e.g. mean_length=3, mask_ratio=0.15) to a prefix
- that will be added to sequences when the associated "noiser" is
- applied.
- To disable these prefixes, use a value of ``None``.
- Default: :func:`ul2_prefix_function` applies the prefix scheme
- suggested in the UL2 paper: http://arxiv.org/abs/2205.05131.
- context_eos (bool, optional): Whether to attach an EOS token to the end
- of the context sequence, marking the transition from context to
- target sequence. Only applicable if decoder_only_format is True.
- Context EOS tokens are always added for encoder-decoder format.
- Default: ``False`` does not attach context EOS.
- """
-
- def __init__(
- self,
- tokenizer: PreTrainedTokenizerBase,
- max_seq_length: int,
- decoder_only_format: bool = False,
- span_mean_lengths_and_ratios: Optional[List] = None,
- sequence_mask_ratios: Optional[Union[List[float], float]] = None,
- allow_pad_trimming: bool = False,
- prefix_function: Optional[PREFIX_FUNCTION] = ul2_prefix_function,
- context_eos: Optional[bool] = None,
- ):
- # Prepare the tokenizer for denoising tasks
- utils.adapt_tokenizer_for_denoising(tokenizer)
-
- self.tokenizer = tokenizer
- self.max_seq_length = max_seq_length
- self.decoder_only_format = decoder_only_format
- self._sentinel_token_ids = np.array(self.tokenizer.sentinel_token_ids)
-
- # Trimming will always be skipped on at least the first __call__
- self._allow_pad_trimming = allow_pad_trimming
- self._seen_first_batch = False
-
- self.context_eos = bool(context_eos) if decoder_only_format else True
-
- # Process the span_mean_lengths_and_ratios argument
- if span_mean_lengths_and_ratios is None:
- # In this case, there are no span corruption tasks
- self.span_mean_lengths_and_ratios = []
- elif isinstance(span_mean_lengths_and_ratios[0], (int, float)):
- # In this case, there is one span corruption task
- if not len(span_mean_lengths_and_ratios) == 2:
- raise ValueError('`span_mean_lengths_and_ratios` must be a ' + \
- 'pair of [mean_length, mask_ratio], a list ' + \
- f'of such pairs, or None. Got {span_mean_lengths_and_ratios}.')
- self.span_mean_lengths_and_ratios = [span_mean_lengths_and_ratios]
- else:
- # In this case, there are one or more span corruption tasks
- span_mean_lengths_and_ratios = list(span_mean_lengths_and_ratios)
- for spec_pair in span_mean_lengths_and_ratios:
- if len(spec_pair) != 2:
- raise ValueError('`span_mean_lengths_and_ratios` must be a ' + \
- 'pair of [mean_length, mask_ratio], a list ' + \
- f'of such pairs, or None. Got {span_mean_lengths_and_ratios}.')
- self.span_mean_lengths_and_ratios = span_mean_lengths_and_ratios
-
- # Process the sequence_mask_ratios argument
- if sequence_mask_ratios is None:
- # In this case, there are no sequence corruption tasks
- self.sequence_mask_ratios = []
- elif isinstance(sequence_mask_ratios, float):
- # In this case, there is one sequence corruption task
- self.sequence_mask_ratios = [sequence_mask_ratios]
- else:
- # In this case, there is one or more sequence corruption tasks
- for ratio in sequence_mask_ratios:
- if not (0 < ratio <= 1.0):
- raise ValueError('`sequence_mask_ratios` must be a float (or list '+\
- 'of floats) that are each >0.0 and <=1.0, or None. '+\
- f'Got {sequence_mask_ratios}.')
- self.sequence_mask_ratios = sequence_mask_ratios
-
- # Populate the noisers so we can learn to denoise them!
- self._noisers = []
- self._smallest_max_raw_length = self.max_seq_length * 100
- self._largest_max_raw_length = 0
- self._uses_span_corruption = False
-
- # Add "noisers" for any span corruption denoising tasks
- # Each mean_length / mask_ratio combo becomes one of the span
- # corruption denoising tasks
- for span_mean_length, span_mask_ratio in self.span_mean_lengths_and_ratios:
- self._uses_span_corruption = True
- if span_mean_length < 0:
- raise ValueError('All span mean lengths must be positive.')
- if not 0 < span_mask_ratio < 1.0:
- raise ValueError(
- 'All span masking ratios must be between 0.0 and 1.0.')
-
- if prefix_function is not None:
- prefix_tokens = prefix_function(span_mask_ratio,
- span_mean_length,
- self.tokenizer)
- else:
- prefix_tokens = None
-
- max_raw_length = _get_max_starting_length(
- max_length=self.max_seq_length,
- mask_ratio=span_mask_ratio,
- mean_span_length=span_mean_length,
- n_prefix_tokens=len(prefix_tokens or []),
- decoder_only_format=self.decoder_only_format,
- context_eos=self.context_eos)
- if max_raw_length < self._smallest_max_raw_length:
- self._smallest_max_raw_length = max_raw_length
- if max_raw_length > self._largest_max_raw_length:
- self._largest_max_raw_length = max_raw_length
-
- kwargs = {
- 'mean_span_length': span_mean_length,
- 'mask_ratio': span_mask_ratio,
- 'prefix_tokens': prefix_tokens,
- 'max_raw_length': max_raw_length,
- }
- self._noisers.append(kwargs)
-
- # Add "noisers" for any sequential denoising tasks
- for sequence_mask_ratio in self.sequence_mask_ratios:
- if prefix_function is not None:
- prefix_tokens = prefix_function(sequence_mask_ratio, None,
- self.tokenizer)
- else:
- prefix_tokens = None
-
- max_raw_length = self.max_seq_length - len(prefix_tokens or []) - 1
- if decoder_only_format and self.context_eos:
- max_raw_length = max_raw_length - 1
-
- if not self._uses_span_corruption and (
- max_raw_length < self._smallest_max_raw_length):
- # We choose not to count sequence denoising in the smallest
- # unless there is only sequence denoising.
- self._smallest_max_raw_length = max_raw_length
- if max_raw_length > self._largest_max_raw_length:
- self._largest_max_raw_length = max_raw_length
-
- kwargs = {
- 'mean_span_length': None,
- 'mask_ratio': sequence_mask_ratio,
- 'prefix_tokens': prefix_tokens,
- 'max_raw_length': max_raw_length,
- }
- self._noisers.append(kwargs)
-
- if not self._noisers:
- raise ValueError(
- 'No denoising tasks were included. Make sure to set ' + \
- '`span_mean_lengths_and_ratios` and/or `sequence_mask_ratios`.')
-
- @property
- def smallest_max_raw_length(self) -> int:
- return int(self._smallest_max_raw_length)
-
- @property
- def largest_max_raw_length(self) -> int:
- return int(self._largest_max_raw_length)
-
- def __call__(self, examples: List[Dict[str,
- Any]]) -> Dict[str, torch.Tensor]:
- """Batch examples processed by the span corrupter."""
- processed_examples = []
- for example in examples:
- # Randomly pick a "noiser" to apply to this example
- noiser = random.choice(self._noisers)
- # Apply it
- processed_examples.append(
- noise_token_sequence(
- example,
- mask_ratio=noiser['mask_ratio'],
- mean_span_length=noiser['mean_span_length'],
- prefix_tokens=noiser['prefix_tokens'],
- max_raw_length=noiser['max_raw_length'],
- max_seq_length=self.max_seq_length,
- tokenizer=self.tokenizer,
- sentinel_token_ids=self._sentinel_token_ids,
- decoder_only_format=self.decoder_only_format,
- context_eos=self.context_eos))
- batch = self.tokenizer.pad(processed_examples)
-
- # This logic prevents trimming on at least the first batch
- if not (self._allow_pad_trimming and self._seen_first_batch):
- self._seen_first_batch = True
- return batch
- self._seen_first_batch = True
-
- # Truncate portions of the inputs that are purely padding
- # (up to a multiple of 8)
- multiple_of = 8
- n_examples_per_length = batch['attention_mask'].sum(0)
- keep_tokens = torch.sum(n_examples_per_length > 0)
- keep_tokens = int(multiple_of * torch.ceil(keep_tokens / multiple_of))
-
- # Note: EncDec formatting will always produce a right-padded batch
- if self.tokenizer.padding_side == 'left' and self.decoder_only_format:
- batch['input_ids'] = batch['input_ids'][:, -keep_tokens:]
- batch['attention_mask'] = batch['attention_mask'][:, -keep_tokens:]
- else:
- batch['input_ids'] = batch['input_ids'][:, :keep_tokens]
- batch['attention_mask'] = batch['attention_mask'][:, :keep_tokens]
-
- if self.decoder_only_format:
- if self.tokenizer.padding_side == 'left':
- batch['labels'] = batch['labels'][:, -keep_tokens:]
- batch['bidirectional_mask'] = batch[
- 'bidirectional_mask'][:, -keep_tokens:]
- else:
- batch['labels'] = batch['labels'][:, :keep_tokens]
- batch['bidirectional_mask'] = batch[
- 'bidirectional_mask'][:, :keep_tokens]
-
- else:
- # Truncate portions of the decoder inputs that are purely padding
- n_examples_per_length = batch['decoder_attention_mask'].sum(0)
- keep_tokens = torch.sum(n_examples_per_length > 0)
- keep_tokens = int(multiple_of *
- torch.ceil(keep_tokens / multiple_of))
-
- batch['labels'] = batch['labels'][:, :keep_tokens]
- batch['decoder_attention_mask'] = batch[
- 'decoder_attention_mask'][:, :keep_tokens]
- batch['decoder_input_ids'] = batch[
- 'decoder_input_ids'][:, :keep_tokens]
-
- # This slicing can produce non-contiguous tensors, so use .contiguous
- # to prevent related problems
- batch = {k: v.contiguous() for k, v in batch.items()}
-
- return batch
-
-
-def build_text_denoising_dataloader(
- cfg: DictConfig,
- tokenizer: PreTrainedTokenizerBase,
- device_batch_size: int,
-) -> DataSpec:
- """Constructor function for a Mixture of Denoisers dataloader.
-
- This function constructs a dataloader that can be used to train an
- encoder-decoder model or a (prefix LM) decoder-only model on a text
- denoising task mixture (e.g. span corruption, or UL2).
-
- The underlying dataset is a :class:`StreamingTextDataset`, allowing you to
- stream raw text data or pre-tokenized text data.
-
- The dataloader uses a :class:`MixtureOfDenoisersCollator` to prepare the
- tokenized examples into training batches.
-
- Args:
- cfg (DictConfig): An omegaconf dictionary used to configure the loader:
- cfg.name (str): The type of dataloader to build. Must = "text_denoising".
- ---
- cfg.dataset.max_seq_len (int): The maximum length of sequences
- in the batch. See :class:`MixtureOfDenoisersCollator` docstring
- for details.
- cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes
- a collator wrapper that packs device_batch_size*packing_ratio
- raw examples into device_batch_size packed examples. This helps
- minimize padding while preserving sequence integrity.
- This adds `sequence_id` to the batch, which indicates which unique
- sequence each token belongs to.
-
- If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with
- zero waste is selected.
- In practice, this may result in > 0 waste because profiling is done on only a portion
- of the dataset.
-
- Note: Using this feature will not change device_batch_size but it
- will determine the number of raw examples consumed by the dataloader
- per batch. Some examples may be discarded if they do not fit when
- packing.
- Select packing_ratio **carefully** based on the dataset
- statistics, max_seq_len, and tolerance for discarding samples!
- The script `scripts/misc/profile_packing.py` can help
- you choose the best packing_ratio.
- See :class:`StreamingTextDataset` for info on other standard config
- options within `cfg.dataset`.
- ---
- cfg.mixture_of_denoisers.decoder_only_format (bool): Whether the
- batches should use the format required for training a decoder-only
- model (if ``True``) or an encoder-decoder model (if ``False``).
- cfg.mixture_of_denoisers.span_mean_lengths_and_ratios (optional): The
- parameters for any span corruption denoising tasks to include in
- the task mixture.
- See :class:`MixtureOfDenoisersCollator` docstring for details.
- cfg.mixture_of_denoisers.sequence_mask_ratios (optional): The
- parameters for any sequence denoising tasks to include in the
- task mixture.
- See :class:`MixtureOfDenoisersCollator` docstring for details.
- cfg.mixture_of_denoisers.allow_pad_trimming (optional): Whether to
- allow the collator to trim padding when possible (if ``True``).
- Defaults to ``False``.
- cfg.mixture_of_denoisers.prefix_function (optional): Set to ``None``
- to disable the UL2-style prefixes that will be automatically
- added by default.
- ---
- See :class:`DataLoader` for standard argument options to the pytorch
- dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc.
- tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
- prepare the data from raw text. Any missing sentinel tokens will
- be added by the collator.
- device_batch_size (int): The size of the batches (number of examples)
- that the dataloader will produce.
-
- Note:
- You can use the script `scripts/misc/profile_packing.py` to quickly test the
- padding/waste rates for different `cfg.dataset.packing_ratio` choices,
- given a starting workload YAML.
- """
- warnings.warn(
- VersionedDeprecationWarning('Text denoising is deprecated.',
- remove_version='0.7.0'))
- assert cfg.name == 'text_denoising', f'Tried to build_denoising text dataloader with cfg.name={cfg.name}'
-
- collate_fn = MixtureOfDenoisersCollator(
- tokenizer=tokenizer,
- max_seq_length=cfg.dataset.max_seq_len,
- decoder_only_format=cfg.mixture_of_denoisers.decoder_only_format,
- span_mean_lengths_and_ratios=cfg.mixture_of_denoisers.get(
- 'span_mean_lengths_and_ratios'),
- sequence_mask_ratios=cfg.mixture_of_denoisers.get(
- 'sequence_mask_ratios'),
- allow_pad_trimming=cfg.mixture_of_denoisers.get('allow_pad_trimming',
- False),
- prefix_function=cfg.mixture_of_denoisers.get('prefix_function',
- ul2_prefix_function),
- context_eos=cfg.mixture_of_denoisers.get('context_eos'))
-
- truncate_to = cfg.mixture_of_denoisers.get('truncate_raw_tokens_to')
- if truncate_to is None:
- # By default, truncate to the largest max raw length of the denoisers
- truncate_to = collate_fn.largest_max_raw_length
- elif isinstance(truncate_to, str):
- if truncate_to.lower() == 'min':
- # Truncate to the smallest max raw length of the denoisers
- truncate_to = collate_fn.smallest_max_raw_length
- elif truncate_to.lower() == 'max':
- # Truncate to the largest max raw length of the denoisers
- truncate_to = collate_fn.largest_max_raw_length
- else:
- raise ValueError(
- f'truncate_raw_tokens_to(="{truncate_to.lower()}") must be "min", "max", a positive int, or None.'
- )
- else:
- if not isinstance(truncate_to, int):
- ValueError(
- f'truncate_raw_tokens_to(={truncate_to}) must be "min", "max", a positive int, or None.'
- )
- if truncate_to < 0:
- ValueError(
- f'truncate_raw_tokens_to(={truncate_to}) must be "min", "max", a positive int, or None.'
- )
-
- dataset = StreamingTextDataset(
- local=cfg.dataset.local,
- tokenizer=tokenizer,
- max_seq_len=truncate_to,
- remote=cfg.dataset.get('remote'),
- split=cfg.dataset.get('split'),
- shuffle=cfg.dataset.get('shuffle', False),
- predownload=cfg.dataset.get('predownload', None),
- keep_zip=cfg.dataset.get('keep_zip', False),
- download_retry=cfg.dataset.get('download_retry', 2),
- download_timeout=cfg.dataset.get('download_timeout', 60),
- validate_hash=cfg.dataset.get('validate_hash', None),
- shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
- num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
- batch_size=device_batch_size,
- )
-
- if dataset.tokenizer.pad_token is None:
- dataset.tokenizer.pad_token = dataset.tokenizer.eos_token
-
- if cfg.dataset.get('packing_ratio'):
- n_examples_to_pack = int(device_batch_size * cfg.dataset.packing_ratio)
- if n_examples_to_pack < device_batch_size:
- raise ValueError('packing_ratio must be >= 1, if supplied')
- if not cfg.mixture_of_denoisers.decoder_only_format:
- raise NotImplementedError(
- 'On-the-fly packing is currently only supported for decoder-only formats.'
- )
- collate_fn = BinPackCollator(
- collator=collate_fn,
- target_batch_size=device_batch_size,
- max_seq_len=cfg.dataset.max_seq_len,
- pad_token_id=dataset.tokenizer.pad_token_id,
- padding_side=dataset.tokenizer.padding_side,
- max_leftover_bins_to_keep=cfg.dataset.get(
- 'max_leftover_bins_to_keep'),
- )
- device_batch_size = n_examples_to_pack
- elif cfg.dataset.get('max_leftover_bins_to_keep') is not None:
- raise ValueError(
- 'cfg.dataset.max_leftover_bins_to_keep has been defined, ' +\
- 'but cfg.dataset.packing_ratio has not been set. Please set ' +\
- 'the latter to turn on packing or remove the former from the config.')
-
- dl = DataLoader(
- dataset,
- collate_fn=collate_fn,
- batch_size=device_batch_size,
- drop_last=cfg.drop_last,
- num_workers=cfg.num_workers,
- pin_memory=cfg.get('pin_memory', True),
- prefetch_factor=cfg.get('prefetch_factor', 2),
- persistent_workers=cfg.get('persistent_workers', False),
- timeout=cfg.get('timeout', 0),
- )
-
- token_counting_func = get_tokens_per_batch_func(
- decoder_only=cfg.mixture_of_denoisers.decoder_only_format)
-
- return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
-
-
-def noise_token_sequence(
- example: Union[torch.Tensor, Mapping[str, Any]],
- mask_ratio: float,
- mean_span_length: Optional[float],
- prefix_tokens: Optional[Sequence[int]],
- max_raw_length: int,
- max_seq_length: int,
- tokenizer: PreTrainedTokenizerBase,
- sentinel_token_ids: np.ndarray,
- decoder_only_format: bool,
- context_eos: bool,
-) -> Dict[str, torch.Tensor]:
- """Span corruption applicable to all UL2 denoising tasks."""
- # Extract the raw text tokens (trim if we need to)
- if isinstance(example, torch.Tensor):
- # If the example is a tensor, assume is the raw tokens with no padding
- tokens = example
- length = len(tokens)
- else:
- tokens = example['input_ids']
- length = sum(example['attention_mask'])
- if length > max_raw_length:
- length = max_raw_length
- if tokenizer.padding_side == 'left':
- tokens = tokens[-length:]
- else:
- tokens = tokens[:length]
-
- prefix_tokens = prefix_tokens or []
-
- if length < 1:
- raise ValueError('Example cannot be empty but token length <1.')
-
- # mean_span_length==None is a special case for "sequential" denoising
- # (where a single span at the end of the sequence is masked)
- if mean_span_length is None:
- # This ensures that exactly 1 span will be produced and that
- # trimming to max_seq_length won't cut off any token.
- # In the decoder-only case, this won't insert new tokens.
- if mask_ratio <= 0.5:
- u = np.random.uniform(low=0.0, high=mask_ratio * 2)
- else:
- u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
- mean_span_length = float(np.round(1 + u * (length - 1)))
- mask_ratio = mean_span_length / length
- use_sentinels = False
- else:
- use_sentinels = True
-
- # Generate the mask
- # Note: this function can be used for all the UL2 noising functions
- mask = _sample_mask_array(length, mask_ratio, mean_span_length)
- # The sequence should always be unmasked at the beginning
- assert mask[0] == 0
-
- # Generate the input/label sequences given the raw tokens and the mask
- tokens_inputs = _apply_mask(tokens,
- mask,
- use_sentinels,
- tokenizer.eos_token_id,
- sentinel_token_ids,
- ensure_eos=context_eos)
- tokens_labels = _apply_mask(tokens,
- 1 - mask,
- use_sentinels,
- tokenizer.eos_token_id,
- sentinel_token_ids,
- ensure_eos=True)
-
- # Tag the inputs with any prefix
- if prefix_tokens:
- tokens_inputs = np.concatenate([prefix_tokens, tokens_inputs])
-
- # Trim if necessary
- if len(tokens_inputs) > max_seq_length:
- raise ValueError('This should not exceed the max length')
- if len(tokens_labels) > max_seq_length:
- raise ValueError('This should not exceed the max length')
-
- tokens_inputs = torch.LongTensor(tokens_inputs)
- tokens_labels = torch.LongTensor(tokens_labels)
-
- if decoder_only_format:
- return _format_tokens_for_decoder_only(tokens_inputs, tokens_labels,
- max_seq_length,
- tokenizer.pad_token_id,
- tokenizer.padding_side)
- return _format_tokens_for_encoder_decoder(tokens_inputs, tokens_labels,
- max_seq_length,
- tokenizer.pad_token_id)
-
-
-def _get_max_starting_length(max_length: int, mask_ratio: float,
- mean_span_length: float, n_prefix_tokens: int,
- decoder_only_format: bool,
- context_eos: bool) -> int:
- """Get max num raw tokens that will fit max_length."""
-
- def sequence_stats(length: int):
- length = np.maximum(length, 2)
- num_noise_tokens = int(np.round(mask_ratio * float(length)))
- num_noise_tokens = np.minimum(np.maximum(num_noise_tokens, 1),
- length - 1)
- num_spans = int(np.round(float(num_noise_tokens) / mean_span_length))
- num_noise_spans = np.maximum(num_spans, 1)
- num_nonnoise_tokens = length - num_noise_tokens
- # Prefix, sentinel, and EOS added to input for Enc-Dec
- extra_inp_tokens = n_prefix_tokens + num_noise_spans + int(context_eos)
- # Sentinel and EOS added to target
- extra_targ_tokens = num_noise_spans + 1
- # Sequence totals after corruption
- total_inp_tokens = num_nonnoise_tokens + extra_inp_tokens
- total_targ_tokens = num_noise_tokens + extra_targ_tokens
- return total_inp_tokens, total_targ_tokens
-
- def length_fits(length: int) -> bool:
- total_inp_tokens, total_targ_tokens = sequence_stats(length)
- if decoder_only_format:
- return (total_inp_tokens + total_targ_tokens) <= max_length
- return (total_inp_tokens <= max_length) and (total_targ_tokens <=
- max_length)
-
- # Start with a definitely too-long sequence and reduce until it fits
- num_raw_tokens = max_length * 2
- while num_raw_tokens > 0:
- if length_fits(num_raw_tokens):
- return num_raw_tokens
- num_raw_tokens -= 1
- raise ValueError(
- 'Unable to find a starting sequence length that can fit given the corruption and max_length parameters.'
- )
-
-
-def _sample_mask_array(length: int, mask_ratio: float,
- mean_span_length: float) -> np.ndarray:
- """Samples a span corruption mask."""
- if mask_ratio == 0.0:
- return np.zeros(length)
- # This first block computes the number of noise/non-noise spans and the
- # total tokens in each. Extra steps are taken to handle edge cases that
- # cause degeneracy.
- starting_length = length
- length = np.maximum(length, 2)
- num_noise_tokens = int(np.round(mask_ratio * float(length)))
- num_noise_tokens = np.minimum(np.maximum(num_noise_tokens, 1), length - 1)
- num_spans = int(np.round(float(num_noise_tokens) / mean_span_length))
- num_noise_spans = np.maximum(num_spans, 1)
- num_nonnoise_tokens = length - num_noise_tokens
-
- # Sample the noise/non-noise span lengths and interleave them to
- # generate the mask array.
- # Note: We always start with a non-noise span.
- def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray:
- """Samples lengths of num_spans segments.
-
- Note: the combined length of segments equals total_tokens.
- """
- span_markers = np.less(np.arange(total_tokens - 1), num_spans -
- 1)[np.random.permutation(total_tokens - 1)]
- span_start_indicator = np.concatenate([np.array([0]), span_markers])
- span_id = np.cumsum(span_start_indicator).reshape(-1, 1)
- spans = np.arange(num_spans).reshape(1, -1)
- span_lengths = np.sum(span_id == spans, axis=0)
- return span_lengths
-
- noise_span_lengths = _sample_span_lengths(num_noise_tokens, num_noise_spans)
- nonnoise_span_lengths = _sample_span_lengths(num_nonnoise_tokens,
- num_noise_spans)
- interleaved_span_lengths = np.reshape(
- np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
- [num_noise_spans * 2])
-
- span_starts = np.cumsum(interleaved_span_lengths)[:-1]
- span_start_indicator = np.zeros(length)
- span_start_indicator[span_starts] = 1
- span_id = np.cumsum(span_start_indicator)
- is_noise = np.equal(np.mod(span_id, 2), 1)
-
- mask = is_noise[:starting_length]
-
- return mask
-
-
-def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],
- mask: np.ndarray,
- use_sentinels: bool,
- eos_token_id: int,
- sentinel_token_ids: np.ndarray,
- ensure_eos: bool = True) -> np.ndarray:
- """Remove or replace masked portions from token sequence."""
- if not use_sentinels:
- # The logic is simple if we don't use sentinel tokens
- noised_tokens = np.array(tokens)[np.logical_not(mask)]
-
- # Ensure there's an end-of-sentence token at the end
- if ensure_eos and (noised_tokens[-1] != eos_token_id):
- noised_tokens = np.concatenate(
- [noised_tokens, np.array([eos_token_id])])
-
- return noised_tokens
-
- # Masking at previous token
- prev_token_mask = np.concatenate([np.array([0]), mask[:-1]])
-
- # Decompose mask into start-of-span mask and non-start-of-span mask
- start_of_noise_span_token = np.logical_and(mask,
- np.logical_not(prev_token_mask))
- nonstart_noise_span_token = np.logical_and(mask, prev_token_mask)
-
- # Replace tokens at the start of each noise span with its corresponding
- # sentinel token
- sentinel_idx = np.minimum(len(sentinel_token_ids),
- np.cumsum(start_of_noise_span_token)) - 1
- tokens = np.where(start_of_noise_span_token,
- sentinel_token_ids[sentinel_idx], tokens)
-
- # Remove masked tokens (but preserving the sentinel tokens)
- noised_tokens = tokens[np.logical_not(nonstart_noise_span_token)]
-
- # Ensure there's an end-of-sentence token at the end
- if ensure_eos and (noised_tokens[-1] != eos_token_id):
- noised_tokens = np.concatenate(
- [noised_tokens, np.array([eos_token_id])])
- return noised_tokens
-
-
-def _format_tokens_for_encoder_decoder(
- tokens_inputs: torch.LongTensor,
- tokens_labels: torch.LongTensor,
- max_seq_length: int,
- pad_token_id: int,
-) -> Dict[str, torch.Tensor]:
- """Package the input/label sequence for an EncDec model."""
- example = {}
- # Re-populate with an empty, padded example
- example['input_ids'] = torch.full((max_seq_length,),
- pad_token_id,
- dtype=torch.int32)
- example['labels'] = torch.full((max_seq_length,),
- _HF_IGNORE_INDEX,
- dtype=torch.int32)
- example['attention_mask'] = torch.zeros_like(example['input_ids'])
- example['decoder_attention_mask'] = torch.zeros_like(example['labels'])
-
- # Fill in with processed results (Note: EncDec format is right-padded)
- example['input_ids'][:len(tokens_inputs)] = tokens_inputs
- example['labels'][:len(tokens_labels)] = tokens_labels
- example['attention_mask'][:len(tokens_inputs)] = 1
- example['decoder_attention_mask'][:len(tokens_labels)] = 1
-
- # Best practice is to include decoder_input_ids (= right-shifted labels)
- example['decoder_input_ids'] = torch.full_like(example['labels'],
- pad_token_id)
- example['decoder_input_ids'][1:len(tokens_labels)] = tokens_labels[:-1]
- return example
-
-
-def _format_tokens_for_decoder_only(
- tokens_inputs: torch.LongTensor,
- tokens_labels: torch.LongTensor,
- max_seq_length: int,
- pad_token_id: int,
- padding_side: str,
-) -> Dict[str, torch.Tensor]:
- """Package the input/label sequence for an decoder-only model."""
- example = {}
- # Re-populate with an empty, padded example
- example['input_ids'] = torch.full((max_seq_length,),
- pad_token_id,
- dtype=torch.int32)
- example['labels'] = torch.full((max_seq_length,),
- _HF_IGNORE_INDEX,
- dtype=torch.int32)
- example['attention_mask'] = torch.full((max_seq_length,),
- 0,
- dtype=torch.bool)
- example['bidirectional_mask'] = torch.full((max_seq_length,),
- 0,
- dtype=torch.bool)
-
- n_input = len(tokens_inputs)
- n_label = len(tokens_labels)
- n_concat = n_input + n_label
- assert n_concat <= max_seq_length, f'{n_concat=}, {n_input=}, {n_label=}'
-
- tokens_concat = torch.concat([tokens_inputs, tokens_labels], dim=0)
-
- # Fill in with the processed results
- if padding_side == 'left':
- example['input_ids'][-n_concat:] = tokens_concat
- # `labels` copies `input_ids` but with -100 at
- # non-loss-generating tokens. `labels` will be shifted in the
- # model code when computing loss.
- example['labels'][-n_concat:] = tokens_concat
- example['labels'][-n_concat:-n_label] = _HF_IGNORE_INDEX
- example['attention_mask'][-n_concat:] = 1
- example['bidirectional_mask'][-n_concat:-n_label] = 1
- else:
- example['input_ids'][:n_concat] = tokens_concat
- # See above comment regarding `labels`
- example['labels'][:n_concat] = tokens_concat
- example['labels'][:n_input] = _HF_IGNORE_INDEX
- example['attention_mask'][:n_concat] = 1
- example['bidirectional_mask'][:n_input] = 1
- return example
-
-
-# Helpful to test if your dataloader is working locally
-# Run `python denoising.py [local] [remote, optional]` and verify that batches
-# are printed out
-if __name__ == '__main__':
- from llmfoundry.utils.builders import build_tokenizer
-
- local = sys.argv[1]
- if len(sys.argv) > 2:
- remote = sys.argv[2]
- else:
- remote = local
- print(f'Reading val split from {remote} -> {local}')
-
- decoder_only = True
-
- cfg = {
- 'name': 'text_denoising',
- 'dataset': {
- 'local': local,
- 'remote': remote,
- 'split': 'val', # 'val_small',
- 'shuffle': False,
- 'max_seq_len': 2048 if decoder_only else 1024,
- 'packing_ratio': 4.5,
- 'predownload': 1000,
- 'keep_zip': True, # in case we need compressed files after testing
- },
- 'mixture_of_denoisers': {
- 'decoder_only_format': decoder_only,
- 'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
- 'sequence_mask_ratios': 0.25,
- },
- 'drop_last': False,
- 'num_workers': 0,
- }
- cfg = om.create(cfg)
- device_batch_size = 2
-
- tokenizer_name = 'EleutherAI/gpt-neox-20b' if decoder_only else 't5-base'
- tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
- tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
- tokenizer_kwargs=tokenizer_kwargs)
-
- loader = build_text_denoising_dataloader(cfg, tokenizer,
- device_batch_size).dataloader
- assert isinstance(loader, DataLoader)
- assert isinstance(loader.dataset, StreamingTextDataset)
-
- print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
-
- packing = cfg.dataset.get('packing_ratio') is not None
- if packing:
- tokenizer = loader.collate_fn.base_collator.tokenizer
- else:
- tokenizer = loader.collate_fn.tokenizer
- batch_ix = 0
- for batch in loader:
- if batch_ix >= 50:
- batch_ix += 1
- break
- if batch_ix >= 5:
- if not packing:
- break
- batch_ix += 1
- continue
- print('\n')
- print('#' * 20, f'Batch {batch_ix}', '#' * 20)
- for k, v in batch.items():
- print(k, v.shape, v.dtype)
- for sample_ix, token_sample in enumerate(batch['input_ids']):
- if cfg.mixture_of_denoisers.decoder_only_format:
- labels = batch['labels'][sample_ix]
- attn_inputs = batch['bidirectional_mask'][sample_ix].to(
- torch.bool)
- attn_full = batch['attention_mask'][sample_ix].to(torch.bool)
- attn_labels = torch.logical_xor(attn_inputs, attn_full)
- print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
- if packing:
- for subseq in range(
- int(batch['sequence_id'][sample_ix].max()) + 1):
- is_subseq = batch['sequence_id'][sample_ix] == subseq
- print(
- '\033[93m{}\033[00m\n'.format('Input: '),
- tokenizer.decode(token_sample[torch.logical_and(
- is_subseq, attn_inputs)]))
- print(
- '\033[92m{}\033[00m\n'.format('Target: '),
- tokenizer.decode(labels[torch.logical_and(
- is_subseq, attn_labels)]))
- else:
- print('\033[91m{}\033[00m\n'.format('Full: '),
- tokenizer.decode(token_sample[attn_full]))
- print('\033[93m{}\033[00m\n'.format('Input: '),
- tokenizer.decode(token_sample[attn_inputs]))
- print('\033[92m{}\033[00m\n'.format('Target: '),
- tokenizer.decode(labels[attn_labels]))
- else:
- labels = batch['labels'][sample_ix]
- attn_inputs = batch['attention_mask'][sample_ix].to(torch.bool)
- attn_labels = batch['decoder_attention_mask'][sample_ix].to(
- torch.bool)
- print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
- print('\033[93m{}\033[00m\n'.format('Input: '),
- tokenizer.decode(token_sample[attn_inputs]))
- print('\033[92m{}\033[00m\n'.format('Target: '),
- tokenizer.decode(labels[attn_labels]))
- batch_ix += 1
-
- if packing:
- print(f'Padding = {100*(1-loader.collate_fn.efficiency):5.2f}%')
- print(f'Waste = {100*loader.collate_fn.waste:5.2f}%')
diff --git a/llmfoundry/data/finetuning/collator.py b/llmfoundry/data/finetuning/collator.py
index d82fad5ec2..6e3babd657 100644
--- a/llmfoundry/data/finetuning/collator.py
+++ b/llmfoundry/data/finetuning/collator.py
@@ -331,31 +331,21 @@ def _process_and_batch_decoder_only(
self._warned_truncated = True
attention_mask = [1] * len(input_ids)
- # bidirectional_mask is used by our prefix lm model variants
- # Note: this will be malformed if any loss-generating tokens are followed by non-loss-generating tokens
- # (such as in the case of multi-turn chat examples)
- bidirectional_mask = [
- 1 if label == _HF_IGNORE_INDEX else 0 for label in labels
- ]
# Annoyingly, we need to pad everything but input_ids
# and attention_mask ourselves
n_total = len(input_ids)
i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
- z_pad = [0] * (self.max_seq_len - n_total)
if self.tokenizer.padding_side == 'left':
labels = i_pad + labels
- bidirectional_mask = z_pad + bidirectional_mask
else:
labels = labels + i_pad
- bidirectional_mask = bidirectional_mask + z_pad
# Update the example
processed_example = {
'input_ids': input_ids,
'labels': labels,
'attention_mask': attention_mask,
- 'bidirectional_mask': bidirectional_mask,
}
processed_examples.append(processed_example)
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 88b388d6ce..a0003d0571 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -552,15 +552,6 @@ def _build_collate_fn(
1)],
skip_special_tokens=False,
clean_up_tokenization_spaces=True))
- print(
- '\033[92m{}\033[00m\n'.format('CONTEXT: '),
- tokenizer.decode(batch['input_ids'][
- j,
- torch.logical_and(
- is_subseq, batch['bidirectional_mask'][j] ==
- 1)],
- skip_special_tokens=False,
- clean_up_tokenization_spaces=True))
print(
'\033[91m{}\033[00m\n'.format('TARGET: '),
tokenizer.decode(batch['input_ids'][
@@ -578,12 +569,6 @@ def _build_collate_fn(
batch['attention_mask'][j] == 1],
skip_special_tokens=False,
clean_up_tokenization_spaces=True))
- print(
- '\033[92m{}\033[00m\n'.format('CONTEXT: '),
- tokenizer.decode(batch['input_ids'][
- j, batch['bidirectional_mask'][j] == 1],
- skip_special_tokens=False,
- clean_up_tokenization_spaces=True))
print(
'\033[91m{}\033[00m\n'.format('TARGET: '),
tokenizer.decode(batch['input_ids'][
diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index fba3ab2d3e..9f4af3099e 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -71,7 +71,6 @@ def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
'input_ids',
'labels',
'attention_mask',
- 'bidirectional_mask',
'sequence_id',
]
# Cut everything down to size
@@ -278,7 +277,6 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int):
'input_ids': pad_token_id,
'labels': -100,
'attention_mask': 0,
- 'bidirectional_mask': 0,
'sequence_id': -1,
}
keys = packed_examples[0].keys()
diff --git a/llmfoundry/metrics/__init__.py b/llmfoundry/metrics/__init__.py
index 116b6dd08c..6c71a3ea08 100644
--- a/llmfoundry/metrics/__init__.py
+++ b/llmfoundry/metrics/__init__.py
@@ -43,11 +43,6 @@
'code_eval_accuracy',
]
-DEFAULT_PREFIX_LM_METRICS = [
- 'language_cross_entropy',
- 'masked_accuracy',
-]
-
DEFAULT_ENC_DEC_METRICS = [
'language_cross_entropy',
'masked_accuracy',
@@ -66,6 +61,5 @@
'MaskedAccuracy',
'DEFAULT_CAUSAL_LM_TRAIN_METRICS',
'DEFAULT_CAUSAL_LM_EVAL_METRICS',
- 'DEFAULT_PREFIX_LM_METRICS',
'DEFAULT_ENC_DEC_METRICS',
]
diff --git a/llmfoundry/models/__init__.py b/llmfoundry/models/__init__.py
index 36234d3c14..ea144225c0 100644
--- a/llmfoundry/models/__init__.py
+++ b/llmfoundry/models/__init__.py
@@ -1,8 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
- ComposerHFT5)
+from llmfoundry.models.hf import ComposerHFCausalLM, ComposerHFT5
from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper,
FMAPIChatAPIEvalWrapper,
OpenAICausalLMEvalWrapper,
@@ -13,7 +12,6 @@
models.register('mpt_causal_lm', func=ComposerMPTCausalLM)
models.register('hf_causal_lm', func=ComposerHFCausalLM)
-models.register('hf_prefix_lm', func=ComposerHFPrefixLM)
models.register('hf_t5', func=ComposerHFT5)
models.register('openai_causal_lm', func=OpenAICausalLMEvalWrapper)
models.register('fmapi_causal_lm', func=FMAPICasualLMEvalWrapper)
@@ -22,7 +20,6 @@
__all__ = [
'ComposerHFCausalLM',
- 'ComposerHFPrefixLM',
'ComposerHFT5',
'MPTConfig',
'MPTPreTrainedModel',
diff --git a/llmfoundry/models/hf/__init__.py b/llmfoundry/models/hf/__init__.py
index d0ab65dfaf..3c35080d6e 100644
--- a/llmfoundry/models/hf/__init__.py
+++ b/llmfoundry/models/hf/__init__.py
@@ -5,12 +5,10 @@
from llmfoundry.models.hf.hf_fsdp import (prepare_hf_causal_lm_model_for_fsdp,
prepare_hf_enc_dec_model_for_fsdp,
prepare_hf_model_for_fsdp)
-from llmfoundry.models.hf.hf_prefix_lm import ComposerHFPrefixLM
from llmfoundry.models.hf.hf_t5 import ComposerHFT5
__all__ = [
'ComposerHFCausalLM',
- 'ComposerHFPrefixLM',
'ComposerHFT5',
'prepare_hf_causal_lm_model_for_fsdp',
'prepare_hf_enc_dec_model_for_fsdp',
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 7640a7b8a8..38ed7a7e70 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -17,11 +17,10 @@
from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
DEFAULT_CAUSAL_LM_TRAIN_METRICS)
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
from llmfoundry.utils.config_utils import pop_config
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
if TYPE_CHECKING:
from peft import PeftConfig
@@ -31,7 +30,7 @@
log = logging.getLogger(__name__)
-class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
+class ComposerHFCausalLM(HuggingFaceModelWithFSDP):
"""Configures a :class:`.HuggingFaceModel` around a Causal LM.
Args:
@@ -54,11 +53,8 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
cfg.use_auth_token (bool, optional): Whether to use the Hugging Face authentication token when
loading from Hugging Face Hub. Default: ``False``.
cfg.use_train_metrics (bool, optional): Whether to use training metrics. Default: ``True``.
- cfg.z_loss (float, optional): The z-loss coefficient. Default: ``0.0``.
cfg.load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``.
cfg.init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``.
- cfg.attention_patch_type (str, optional): Which attention patch to use for llama models. Default: ``None``.
- Deprecated. Will automatically use flash attention 2.
cfg.use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``.
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
"""
@@ -88,26 +84,15 @@ def __init__(self, om_model_config: DictConfig,
load_in_8bit = om_model_config.get('load_in_8bit', False)
# Set up config args for the model construction and base classes
- z_loss = om_model_config.get('z_loss', 0.0)
init_device = om_model_config.get('init_device', 'cpu')
# Resolve "mixed" init device to either "cpu" or "meta"
resolved_init_device = hf_get_init_device(init_device)
- attention_patch_type = om_model_config.get('attention_patch_type', None)
- if attention_patch_type is not None:
- warnings.warn(
- VersionedDeprecationWarning(
- 'attention_patch_type is deprecated and will automatically use flash attention 2. '
- +
- 'We recommend `use_flash_attention_2: true` for llama models.',
- remove_version='0.7.0'))
- use_flash_attention_2 = True
-
requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
- + 'Please `pip install llm-foundry[gpu-flash2]`.')
+ + 'Please `pip install llm-foundry[gpu]`.')
peft_config_dict = pop_config(om_model_config,
'peft_config',
@@ -269,7 +254,6 @@ def _autoset_attn_implementation_monkeypatch(
tokenizer=tokenizer,
metrics=train_metrics,
eval_metrics=eval_metrics,
- z_loss=z_loss,
init_device=init_device,
peft_config=peft_config,
)
diff --git a/llmfoundry/models/hf/hf_prefix_lm.py b/llmfoundry/models/hf/hf_prefix_lm.py
deleted file mode 100644
index 67060a02b8..0000000000
--- a/llmfoundry/models/hf/hf_prefix_lm.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Implements a Hugging Prefix LM wrapped inside a :class:`.ComposerModel`."""
-
-from __future__ import annotations
-
-from typing import Mapping, MutableMapping
-
-from composer.utils import dist
-from omegaconf import DictConfig
-from transformers import (AutoConfig, AutoModelForCausalLM,
- PreTrainedTokenizerBase)
-
-from llmfoundry.metrics import DEFAULT_PREFIX_LM_METRICS
-from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
- add_bidirectional_mask_if_missing,
- convert_hf_causal_lm_to_prefix_lm,
- init_empty_weights)
-
-__all__ = ['ComposerHFPrefixLM']
-
-
-class ComposerHFPrefixLM(HuggingFaceModelWithZLoss):
- """Configures a :class:`.HuggingFaceModel` around a Prefix LM.
-
- Note: HuggingFace does not natively support Prefix LM-style models. This function uses
- `transformers.AutoModelForCausalLM` to instantiate a Causal LM, then uses a conversion utility
- to turn the model into a Prefix LM. Currently, that conversion utility only supports the
- following HuggingFace Causal LM types:
- - `GPT2LMHeadModel`
- - `GPTNeoForCausalLM`
- - `GPTNeoXForCausalLM`
- - `GPTJForCausalLM`
- - `BloomForCausalLM`
- - `OPTForCausalLM`
-
- Args:
- cfg (DictConfig): An omegaconf dictionary used to configure the model:
- cfg.pretrained_model_name_or_path (str): The name of or local path to
- the HF model (e.g., `gpt2` to instantiate a GPT2LMHeadModel). The model
- will be converted to a Prefix LM during initialization.
- cfg.config_overrides (dict, optional): An optional dictionary of keyword
- arguments that override the default configuration associated with
- cfg.pretrained_model_name_or_path. Default: ``{}``.
- cfg.pretrained (bool): Whether to instantiate the model with pre-trained
- weights coming from cfg.pretrained_model_name_or_path. If ``True``,
- cfg.config_overrides must be compatible with the pre-trained weights.
- cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
- initialize the model on. Currently, `meta` is only supported when
- cfg.pretrained is ``False``. Default: ``'cpu'``.
- cfg.z_loss (float, optional): The coefficient of the z-loss. If >0.0, this
- the z-loss will be multiplied by this value before being added to the
- standard loss term. Default: ``0.0``.
- cfg.adapt_vocab_for_denoising (bool, optional): Whether to adapt the vocab
- of the model/tokenizer to include sentinel tokens that are used in denoising
- tasks like Span Corruption. If you intend to load from an existing Composer
- checkpoint that was trained on such a task, set this to ``True`` to ensure
- that the model vocab size matches your checkpoint's vocab size when loading
- the weights. Default: ``False``.
- tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
- """
-
- def __init__(self, om_model_config: DictConfig,
- tokenizer: PreTrainedTokenizerBase):
- from llmfoundry.utils.builders import build_metric
-
- config = AutoConfig.from_pretrained(
- om_model_config.pretrained_model_name_or_path,
- trust_remote_code=om_model_config.get('trust_remote_code', True),
- use_auth_token=om_model_config.get('use_auth_token', False),
- )
-
- # set config overrides
- for k, v in om_model_config.get('config_overrides', {}).items():
- if not hasattr(config, k):
- raise ValueError(
- f'config does not have attribute "{k}" to override ({k}: {v}).'
- )
-
- attr = getattr(config, k)
- if isinstance(attr, Mapping):
- extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
- if extra_keys:
- raise ValueError(
- f'Config dict override got unknown keys. ' +
- f'Extra keys: {extra_keys}. ' +
- f'Expected (a subset of) keys: {list(attr.keys())}.')
- getattr(config, k).update(v)
- else:
- setattr(config, k, v)
-
- # Set up the tokenizer (add tokens for denoising sentinels if needed)
- if om_model_config.get('adapt_vocab_for_denoising', False):
- adapt_tokenizer_for_denoising(tokenizer)
-
- init_device = om_model_config.get('init_device', 'cpu')
-
- # Get the device we want to initialize, and use the
- # resolved version to initialize the HF model
- resolved_init_device = hf_get_init_device(init_device)
-
- # We need to have all non-zero local ranks be not-pretrained
- # Rank 0 will still be pretrained, and distribute the weights appropriately
- if dist.get_local_rank() != 0 and init_device == 'mixed':
- om_model_config.pretrained = False
-
- if resolved_init_device == 'cpu':
- if om_model_config.pretrained:
- model = AutoModelForCausalLM.from_pretrained(
- om_model_config.pretrained_model_name_or_path,
- config=config)
- else:
- model = AutoModelForCausalLM.from_config(config)
- elif resolved_init_device == 'meta':
- if om_model_config.pretrained:
- raise ValueError(
- 'Setting cfg.pretrained=True is not supported when init_device="meta".'
- )
- with init_empty_weights(include_buffers=False):
- model = AutoModelForCausalLM.from_config(config)
- else:
- raise ValueError(
- f'init_device="{init_device}" must be either "cpu" or "meta".')
-
- # Convert the Causal LM into a Prefix LM via our custom wrapper
- model = convert_hf_causal_lm_to_prefix_lm(model)
-
- metrics = [
- build_metric(metric, {}) for metric in DEFAULT_PREFIX_LM_METRICS +
- om_model_config.get('additional_train_metrics', [])
- ]
-
- composer_model = super().__init__(model=model,
- shift_labels=True,
- tokenizer=tokenizer,
- metrics=metrics,
- z_loss=om_model_config.get(
- 'z_loss', 0.0),
- init_device=init_device)
-
- return composer_model
-
- def forward(self, batch: MutableMapping):
- # Add bidirectional_mask if it is missing and can be constructed
- add_bidirectional_mask_if_missing(batch)
- return super().forward(batch)
diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py
index 5956d49cc8..b9c1df64cf 100644
--- a/llmfoundry/models/hf/hf_t5.py
+++ b/llmfoundry/models/hf/hf_t5.py
@@ -14,16 +14,15 @@
from llmfoundry.metrics import DEFAULT_ENC_DEC_METRICS
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
- init_empty_weights)
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
+from llmfoundry.models.utils import init_empty_weights
from llmfoundry.utils.warnings import experimental_class
__all__ = ['ComposerHFT5']
@experimental_class('ComposerHFT5')
-class ComposerHFT5(HuggingFaceModelWithZLoss):
+class ComposerHFT5(HuggingFaceModelWithFSDP):
"""Configures a :class:`.HuggingFaceModel` around a T5.
Note: This function uses `transformers.T5ForConditionalGeneration`. Future releases
@@ -42,15 +41,6 @@ class ComposerHFT5(HuggingFaceModelWithZLoss):
cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
initialize the model on. Currently, `meta` is only supported when
cfg.pretrained is ``False``. Default: ``'cpu'``.
- cfg.z_loss (float, optional): The coefficient of the z-loss. If >0.0, this
- the z-loss will be multiplied by this value before being added to the
- standard loss term. Default: ``0.0``.
- cfg.adapt_vocab_for_denoising (bool, optional): Whether to adapt the vocab
- of the model/tokenizer to include sentinel tokens that are used in denoising
- tasks like Span Corruption. If you intend to load from an existing Composer
- checkpoint that was trained on such a task, set this to ``True`` to ensure
- that the model vocab size matches your checkpoint's vocab size when loading
- the weights. Default: ``False``.
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
"""
@@ -87,10 +77,6 @@ def __init__(self, om_model_config: DictConfig,
raise ValueError(f'Model type "hf_t5" currently only supports T5 models ' +\
f'using configs where `is_encoder_decoder` is ``True``.')
- # Set up the tokenizer (add tokens for denoising sentinels if needed)
- if om_model_config.get('adapt_vocab_for_denoising', False):
- adapt_tokenizer_for_denoising(tokenizer)
-
init_device = om_model_config.get('init_device', 'cpu')
# Get the device we want to initialize, and use the
@@ -128,8 +114,6 @@ def __init__(self, om_model_config: DictConfig,
composer_model = super().__init__(model=model,
tokenizer=tokenizer,
metrics=metrics,
- z_loss=om_model_config.get(
- 'z_loss', 0.0),
init_device=init_device)
return composer_model
diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py
index e9f4c89796..2ba88d390c 100644
--- a/llmfoundry/models/hf/model_wrapper.py
+++ b/llmfoundry/models/hf/model_wrapper.py
@@ -5,11 +5,9 @@
from __future__ import annotations
-import warnings
from collections import UserDict
from typing import TYPE_CHECKING, List, Mapping, Optional
-import torch
import transformers
from composer.models.huggingface import HuggingFaceModel
from torchmetrics import Metric
@@ -17,7 +15,6 @@
from transformers.utils.generic import ModelOutput
from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
if TYPE_CHECKING:
from peft import PeftConfig
@@ -26,19 +23,9 @@
_HF_IGNORE_INDEX = -100
-class HuggingFaceModelWithZLoss(HuggingFaceModel):
+class HuggingFaceModelWithFSDP(HuggingFaceModel):
"""Wrapper around HuggingFaceModel.
- This adds z-loss, which is used in some training contexts,
- and is a convenient way to patch features that are generically
- useful for HF models.
- See use of z_loss in PaLM: https://arxiv.org/abs/2204.02311v3, Section 5.
- Also, from https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666:
- Two uses of z_loss are:
- - To keep the logits from drifting too far from zero, which can cause
- unacceptable roundoff errors in bfloat16.
- - To encourage the logits to be normalized log-probabilities.
-
Handles preparation for FSDP wrapping.
"""
@@ -47,7 +34,6 @@ def __init__(self,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
metrics: Optional[List[Metric]] = None,
eval_metrics: Optional[List[Metric]] = None,
- z_loss: float = 0.0,
shift_labels: bool = False,
init_device: Optional[str] = None,
peft_config: Optional['PeftConfig'] = None):
@@ -61,9 +47,6 @@ def __init__(self,
peft_config=peft_config,
should_save_peft_only=True,
)
- self.z_loss = float(z_loss)
- if self.z_loss < 0.0:
- raise ValueError(f'z_loss(={z_loss}) cannot be negative.')
# Note: We need to add the FSDP related attributes to the model AFTER the super init,
# so that the (possible) embedding resizing doesn't destroy them
@@ -88,27 +71,6 @@ def forward(self, batch: Mapping):
def loss(self, outputs: ModelOutput, batch: Mapping):
if self.config.use_return_dict:
- loss, logits = outputs['loss'], outputs['logits']
- else:
- # loss is at index 0 in the output tuple, logits are at index 1
- loss, logits = outputs[:2]
- if self.z_loss == 0.0:
- return loss
-
- warnings.warn(
- VersionedDeprecationWarning('z-loss is deprecated.',
- remove_version='0.7.0'))
-
- # Add a z_loss to the standard loss
- logits_flat = logits.view(-1, logits.size(-1))
- labels_flat = batch['labels'].view(-1)
- log_z = torch.logsumexp(logits_flat[labels_flat != _HF_IGNORE_INDEX],
- dim=1)
- log_z2 = log_z**2
- z_loss = log_z2.mean() * self.z_loss
- if self.config.use_return_dict:
- outputs['loss'] += z_loss
return outputs['loss']
- else:
- outputs[0] += z_loss
- return outputs[0]
+ # loss is at index 0 in the output tuple, logits are at index 1
+ return outputs[:2]
diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py
index 05350b059b..df4216b81c 100644
--- a/llmfoundry/models/layers/__init__.py
+++ b/llmfoundry/models/layers/__init__.py
@@ -4,7 +4,7 @@
from llmfoundry.models.layers.attention import (
ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
- flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn)
+ flash_attn_fn, scaled_multihead_dot_product_attention)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
@@ -14,7 +14,6 @@
__all__ = [
'scaled_multihead_dot_product_attention',
'flash_attn_fn',
- 'triton_flash_attn_fn',
'MultiheadAttention',
'MultiQueryAttention',
'GroupedQueryAttention',
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 79780bccee..1deca69eb2 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -50,7 +50,7 @@ def check_alibi_support(attention_impl: str) -> bool:
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
original_is_causal: bool) -> bool:
# disable causal when it is not needed
- # necessary for flash & triton for generation with kv_cache
+ # necessary for flash for generation with kv_cache
if original_is_causal and num_query_tokens != num_key_tokens:
if num_query_tokens != 1:
raise NotImplementedError(
@@ -100,7 +100,7 @@ def scaled_multihead_dot_product_attention(
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
if past_key_value is not None:
- # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
+ # attn_impl: flash attn uses kernels which expect input shape [b, s, h, d_head].
# kv_cache is therefore stored using that shape.
# attn_impl: torch stores the kv_cache in the ordering which is most advantageous
# for its attn computation ie
@@ -341,123 +341,13 @@ def flash_attn_fn(
return output, None, past_key_value
-def triton_flash_attn_fn(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- n_heads: int,
- kv_n_heads: int,
- past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
- softmax_scale: Optional[float] = None,
- attn_bias: Optional[torch.Tensor] = None,
- key_padding_mask: Optional[torch.Tensor] = None,
- is_causal: bool = False,
- dropout_p: float = 0.0,
- training: bool = False,
- needs_weights: bool = False,
-) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
- torch.Tensor]]]:
- try:
- from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
- except:
- _installed = False
- if version.parse(torch.__version__) < version.parse('2.0.0'):
- _installed = True
- # if torch1.13.1 revert to using triton flash attn from HazyResearch
- # with flash-attn==1.0.9 and triton==2.0.0.dev20221202
- try:
- from flash_attn.flash_attn_triton import flash_attn_func
- except:
- _installed = False
- if not _installed:
- # installing triton-pre-mlir works for both torch1.13.1 and torch2.0+
- # default recommendation is to install this variant
- raise RuntimeError(
- 'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU '
- +
- 'and `pip install .[gpu]` if installing from llm-foundry source or '
- +
- '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` '
- +
- 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). '
- +
- 'Note: (1) requires you have CMake and PyTorch already installed.'
- )
-
- check_valid_inputs(query, key, value)
-
- if past_key_value is not None:
- if len(past_key_value) != 0:
- key = torch.cat([past_key_value[0], key], dim=1)
- value = torch.cat([past_key_value[1], value], dim=1)
-
- past_key_value = (key, value)
-
- if attn_bias is not None:
- # clamp to 0 necessary for torch 2.0 compile()
- _s_q = max(0, attn_bias.size(2) - query.size(1))
- _s_k = max(0, attn_bias.size(3) - key.size(1))
- attn_bias = attn_bias[:, :, _s_q:, _s_k:]
-
- if dropout_p:
- raise NotImplementedError(
- f'Dropout not implemented for attn_impl: triton.')
- dropout_p = dropout_p if training else 0.0
-
- if needs_weights:
- raise NotImplementedError(
- f'attn_impl: triton cannot return attn weights.')
-
- if key_padding_mask is not None:
- warnings.warn(
- 'Propagating key_padding_mask to the attention module ' +\
- 'and applying it within the attention module can cause ' +\
- 'unnecessary computation/memory usage. Consider integrating ' +\
- 'into attn_bias once and passing that to each attention ' +\
- 'module instead.'
- )
- b_size, s_k = key_padding_mask.shape[:2]
-
- if attn_bias is None:
- attn_bias = query.new_zeros(b_size, 1, 1, s_k)
-
- attn_bias = attn_bias.masked_fill(
- ~key_padding_mask.view((b_size, 1, 1, s_k)),
- torch.finfo(query.dtype).min)
-
- query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
- key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
- value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
-
- # multi-query case
- if kv_n_heads == 1:
- # necessary to repeat instead of expand tensor because
- # output contains NaN in edge cases such as with head dimension = 8
- key = key.repeat(1, 1, n_heads, 1)
- value = value.repeat(1, 1, n_heads, 1)
- # grouped query case
- elif kv_n_heads < n_heads:
- # Each query belong to a group of kv heads of group size n_heads // kv_n_heads
- # We repeat each kv head by the group size number to use the underlying MHA kernels
- key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
- value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
-
- reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
- attn_output = flash_attn_func( # type: ignore
- query, key, value, attn_bias, reset_is_causal, softmax_scale)
-
- output = attn_output.view(*attn_output.shape[:2], -1) # type: ignore
-
- return output, None, past_key_value
-
-
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
and Multi-query attention (MQA).
This allows the user to set a variable of number of kv_n_heads, rather than
- just n_heads or 1, as in MHA and MQA. Using torch or triton attention
+ just n_heads or 1, as in MHA and MQA. Using torch attention
implementation enables user to also use additive bias.
"""
@@ -466,7 +356,7 @@ def __init__(
d_model: int,
n_heads: int,
kv_n_heads: int,
- attn_impl: str = 'triton',
+ attn_impl: str = 'flash',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
@@ -538,8 +428,6 @@ def __init__(
if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
- elif self.attn_impl == 'triton':
- self.attn_fn = triton_flash_attn_fn
elif self.attn_impl == 'torch':
self.attn_fn = scaled_multihead_dot_product_attention
else:
@@ -679,15 +567,14 @@ def forward(
class MultiheadAttention(GroupedQueryAttention):
"""Multi-head self attention.
- Using torch or triton attention implementation enables user to also use
- additive bias.
+ Using torch attention implementation enables user to also use additive bias.
"""
def __init__(
self,
d_model: int,
n_heads: int,
- attn_impl: str = 'triton',
+ attn_impl: str = 'flash',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
@@ -720,15 +607,14 @@ def __init__(
class MultiQueryAttention(GroupedQueryAttention):
"""Multi-Query self attention.
- Using torch or triton attention implementation enables user to also use
- additive bias.
+ Using torch attention implementation enables user to also use additive bias.
"""
def __init__(
self,
d_model: int,
n_heads: int,
- attn_impl: str = 'triton',
+ attn_impl: str = 'flash',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
qk_gn: bool = False,
@@ -759,17 +645,16 @@ def __init__(
def attn_bias_shape(
- attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
- prefix_lm: bool, causal: bool,
+ attn_impl: str, n_heads: int, seq_len: int, alibi: bool, causal: bool,
use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
if attn_impl == 'flash':
return None
- elif attn_impl in ['torch', 'triton']:
+ elif attn_impl == 'torch':
if alibi:
- if (prefix_lm or not causal) or use_sequence_id:
+ if (not causal) or use_sequence_id:
return (1, n_heads, seq_len, seq_len)
return (1, n_heads, 1, seq_len)
- elif prefix_lm or use_sequence_id:
+ elif use_sequence_id:
return (1, 1, seq_len, seq_len)
return None
else:
@@ -787,7 +672,7 @@ def build_attn_bias(
) -> Optional[torch.Tensor]:
if attn_impl == 'flash':
return None
- elif attn_impl in ['torch', 'triton']:
+ elif attn_impl == 'torch':
if alibi:
# in place add alibi to attn bias
device, dtype = attn_bias.device, attn_bias.dtype
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 4ac43a8bac..855df7903f 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -20,12 +20,11 @@
attn_config_defaults: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
- 'attn_impl': 'triton',
+ 'attn_impl': 'flash',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
- 'prefix_lm': False,
'attn_uses_sequence_id': False,
'sliding_window_size': -1,
'alibi': False,
@@ -79,9 +78,9 @@ def __init__(
# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
- 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id',
- 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl',
- 'rope_dail_config', 'rope_hf_config'
+ 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max',
+ 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config',
+ 'rope_hf_config'
}
attn_config_subset_for_attn_class = {
k: v
diff --git a/llmfoundry/models/layers/flash_attn_triton.py b/llmfoundry/models/layers/flash_attn_triton.py
deleted file mode 100644
index 9276d0f917..0000000000
--- a/llmfoundry/models/layers/flash_attn_triton.py
+++ /dev/null
@@ -1,835 +0,0 @@
-"""
-Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
-update imports to use 'triton_pre_mlir'
-
-*Experimental* implementation of FlashAttention in Triton.
-Tested with triton==2.0.0.dev20221202.
-Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
-other than 64:
-https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
-We'll update this implementation with the new Triton backend once this is fixed.
-
-We use the FlashAttention implementation from Phil Tillet a starting point.
-https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
-
-Changes:
-- Implement both causal and non-causal attention.
-- Implement both self-attention and cross-attention.
-- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
-- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
-- Support attention bias.
-- Speed up the forward pass a bit, and only store the LSE instead of m and l.
-- Make the backward for d=128 much faster by reducing register spilling.
-- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
-small batch size * nheads.
-
-Caution:
-- This is an *experimental* implementation. The forward pass should be quite robust but
-I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
-- This implementation has only been tested on A100.
-- If you plan to use headdim other than 64 and 128, you should test for race conditions
-(due to the Triton compiler), as done in tests/test_flash_attn.py
-"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
-for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
-that there are none left for other head dimensions.
-
-Differences between this Triton version and the CUDA version:
-- Triton version doesn't support dropout.
-- Triton forward is generally faster than CUDA forward, while Triton backward is
-generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
-than CUDA forward + backward.
-- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
-- Triton version supports attention bias, while CUDA version doesn't.
-"""
-
-import math
-
-import torch
-
-import triton_pre_mlir as triton
-import triton_pre_mlir.language as tl
-
-
-# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
-# @triton.autotune(
-# configs=[
-# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
-# # This config has a race condition when EVEN_M == False, disabling it for now.
-# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
-# ],
-# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
-# )
-@triton.heuristics(
- {
- "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
- "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
- "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
- }
-)
-@triton.jit
-def _fwd_kernel(
- Q, K, V, Bias, Out,
- Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
- softmax_scale,
- stride_qb, stride_qh, stride_qm,
- stride_kb, stride_kh, stride_kn,
- stride_vb, stride_vh, stride_vn,
- stride_bb, stride_bh, stride_bm,
- stride_ob, stride_oh, stride_om,
- nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
- CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
- BIAS_TYPE: tl.constexpr,
- IS_CAUSAL: tl.constexpr,
- BLOCK_HEADDIM: tl.constexpr,
- EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-):
- start_m = tl.program_id(0)
- off_hb = tl.program_id(1)
- off_b = off_hb // nheads
- off_h = off_hb % nheads
- # off_b = tl.program_id(1)
- # off_h = tl.program_id(2)
- # off_hb = off_b * nheads + off_h
- # initialize offsets
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = tl.arange(0, BLOCK_N)
- offs_d = tl.arange(0, BLOCK_HEADDIM)
- # Initialize pointers to Q, K, V
- # Adding parenthesis around indexing might use int32 math instead of int64 math?
- # https://github.com/openai/triton/issues/741
- # I'm seeing a tiny bit of difference (5-7us)
- q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
- k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
- v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
- if BIAS_TYPE == 'vector':
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
- elif BIAS_TYPE == 'matrix':
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
- # initialize pointer to m and l
- t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
- lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
- acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
- # load q: it will stay in SRAM throughout
- # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
- # tl.load(q_ptrs), we get the wrong output!
- if EVEN_M & EVEN_N:
- if EVEN_HEADDIM:
- q = tl.load(q_ptrs)
- else:
- q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
- else:
- if EVEN_HEADDIM:
- q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
- else:
- q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
- other=0.0)
- # loop over k, v and update accumulator
- end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
- for start_n in range(0, end_n, BLOCK_N):
- start_n = tl.multiple_of(start_n, BLOCK_N)
- # -- compute qk ----
- if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
- if EVEN_HEADDIM:
- k = tl.load(k_ptrs + start_n * stride_kn)
- else:
- k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
- else:
- if EVEN_HEADDIM:
- k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
- other=0.0)
- else:
- k = tl.load(k_ptrs + start_n * stride_kn,
- mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
- other=0.0)
- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
- qk += tl.dot(q, k, trans_b=True)
- # Trying to combine the two masks seem to make the result wrong
- if not EVEN_N: # Need to mask out otherwise the softmax is wrong
- qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
- if IS_CAUSAL:
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
- if BIAS_TYPE != 'none':
- if BIAS_TYPE == 'vector':
- if EVEN_N:
- bias = tl.load(b_ptrs + start_n).to(tl.float32)
- else:
- bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
- bias = bias[None, :]
- elif BIAS_TYPE == 'matrix':
- if EVEN_M & EVEN_N:
- bias = tl.load(b_ptrs + start_n).to(tl.float32)
- else:
- bias = tl.load(b_ptrs + start_n,
- mask=(offs_m[:, None] < seqlen_q)
- & ((start_n + offs_n)[None, :] < seqlen_k),
- other=0.0).to(tl.float32)
- # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
- # can then fuse the mult and add into an fma instruction. But if we have bias we need to
- # to multiply with softmax_scale here.
- qk = qk * softmax_scale + bias
- m_ij = tl.maximum(tl.max(qk, 1), lse_i)
- p = tl.exp(qk - m_ij[:, None])
- else:
- m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
- p = tl.exp(qk * softmax_scale - m_ij[:, None])
- l_ij = tl.sum(p, 1)
-
- # scale acc_o
- acc_o_scale = tl.exp(m_i - m_ij)
-
- # # -- update output accumulator --
- # BUG: have to store and immediately load
- tl.store(t_ptrs, acc_o_scale)
- acc_o_scale = tl.load(t_ptrs)
- acc_o = acc_o * acc_o_scale[:, None]
- # update acc_o
- if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
- if EVEN_HEADDIM:
- v = tl.load(v_ptrs + start_n * stride_vn)
- else:
- v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
- else:
- if EVEN_HEADDIM:
- v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
- other=0.0)
- else:
- v = tl.load(v_ptrs + start_n * stride_vn,
- mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
- other=0.0)
- p = p.to(v.dtype)
- acc_o += tl.dot(p, v)
-
- # -- update statistics
- m_i = m_ij
- l_i_new = tl.exp(lse_i - m_ij) + l_ij
- lse_i = m_ij + tl.log(l_i_new)
-
- o_scale = tl.exp(m_i - lse_i)
- # BUG: have to store and immediately load
- tl.store(t_ptrs, o_scale)
- o_scale = tl.load(t_ptrs)
- acc_o = acc_o * o_scale[:, None]
- # rematerialize offsets to save registers
- start_m = tl.program_id(0)
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- # write back l and m
- lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
- tl.store(lse_ptrs, lse_i)
- # initialize pointers to output
- offs_d = tl.arange(0, BLOCK_HEADDIM)
- out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
- if EVEN_M:
- if EVEN_HEADDIM:
- tl.store(out_ptrs, acc_o)
- else:
- tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
- else:
- if EVEN_HEADDIM:
- tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
- else:
- tl.store(out_ptrs, acc_o,
- mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
-
-
-@triton.jit
-def _bwd_preprocess_do_o_dot(
- Out, DO, Delta,
- stride_ob, stride_oh, stride_om,
- stride_dob, stride_doh, stride_dom,
- nheads, seqlen_q, seqlen_q_rounded, headdim,
- BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
-):
- start_m = tl.program_id(0)
- off_hb = tl.program_id(1)
- off_b = off_hb // nheads
- off_h = off_hb % nheads
- # initialize offsets
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_d = tl.arange(0, BLOCK_HEADDIM)
- # load
- o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
- mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
- do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
- mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
- delta = tl.sum(o * do, axis=1)
- # write-back
- tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
-
-
-@triton.jit
-def _bwd_store_dk_dv(
- dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
- EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
-):
- # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
- # if we just call tl.store(dv_ptrs), there's a race condition
- if EVEN_N & EVEN_M:
- if EVEN_HEADDIM:
- tl.store(dv_ptrs, dv)
- tl.store(dk_ptrs, dk)
- else:
- tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
- tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
- else:
- if EVEN_HEADDIM:
- tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
- tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
- else:
- tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
- tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
-
-
-@triton.jit
-def _bwd_kernel_one_col_block(
- start_n,
- Q, K, V, Bias,
- DO, DQ, DK, DV,
- LSE, D,
- softmax_scale,
- stride_qm, stride_kn, stride_vn, stride_bm,
- stride_dom, stride_dqm, stride_dkn, stride_dvn,
- seqlen_q, seqlen_k, headdim,
- ATOMIC_ADD: tl.constexpr,
- BIAS_TYPE: tl.constexpr,
- IS_CAUSAL: tl.constexpr,
- BLOCK_HEADDIM: tl.constexpr,
- EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-):
- # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
- begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
- # initialize row/col offsets
- offs_qm = begin_m + tl.arange(0, BLOCK_M)
- offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_m = tl.arange(0, BLOCK_M)
- offs_d = tl.arange(0, BLOCK_HEADDIM)
- # initialize pointers to value-like data
- q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
- k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
- v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
- do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
- dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
- if BIAS_TYPE == 'vector':
- b_ptrs = Bias + offs_n
- elif BIAS_TYPE == 'matrix':
- b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
- # initialize dv and dk
- dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
- dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
- # There seems to be some problem with Triton pipelining that makes results wrong for
- # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
- # may have zero step, and pipelining with the bias matrix could screw it up.
- # So we just exit early.
- if begin_m >= seqlen_q:
- dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
- dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
- _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
- EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
- return
- # k and v stay in SRAM throughout
- # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
- # if we just call tl.load(k_ptrs), we get the wrong output!
- if EVEN_N & EVEN_M:
- if EVEN_HEADDIM:
- k = tl.load(k_ptrs)
- v = tl.load(v_ptrs)
- else:
- k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
- v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
- else:
- if EVEN_HEADDIM:
- k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
- v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
- else:
- k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
- other=0.0)
- v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
- other=0.0)
- # loop over rows
- num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
- for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
- start_m = tl.multiple_of(start_m, BLOCK_M)
- offs_m_curr = start_m + offs_m
- # load q, k, v, do on-chip
- # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
- if EVEN_M & EVEN_HEADDIM:
- q = tl.load(q_ptrs)
- else:
- if EVEN_HEADDIM:
- q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
- else:
- q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
- & (offs_d[None, :] < headdim), other=0.0)
- # recompute p = softmax(qk, dim=-1).T
- qk = tl.dot(q, k, trans_b=True)
- # Trying to combine the two masks seem to make the result wrong
- if not EVEN_N: # Need to mask out otherwise the softmax is wrong
- qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
- if IS_CAUSAL:
- qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
- if BIAS_TYPE != 'none':
- tl.debug_barrier() # Race condition otherwise
- if BIAS_TYPE == 'vector':
- if EVEN_N:
- bias = tl.load(b_ptrs).to(tl.float32)
- else:
- bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
- bias = bias[None, :]
- elif BIAS_TYPE == 'matrix':
- if EVEN_M & EVEN_N:
- bias = tl.load(b_ptrs).to(tl.float32)
- else:
- bias = tl.load(b_ptrs,
- mask=(offs_m_curr[:, None] < seqlen_q)
- & (offs_n[None, :] < seqlen_k),
- other=0.0).to(tl.float32)
- qk = qk * softmax_scale + bias
- # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
- # Also wrong for headdim=64.
- if not (EVEN_M & EVEN_HEADDIM):
- tl.debug_barrier()
- lse_i = tl.load(LSE + offs_m_curr)
- if BIAS_TYPE == 'none':
- p = tl.exp(qk * softmax_scale - lse_i[:, None])
- else:
- p = tl.exp(qk - lse_i[:, None])
- # compute dv
- # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
- # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
- # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
- # the output is correct.
- if EVEN_M & EVEN_HEADDIM:
- do = tl.load(do_ptrs)
- else:
- # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
- do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
- & (offs_d[None, :] < headdim), other=0.0)
- # if EVEN_M:
- # if EVEN_HEADDIM:
- # do = tl.load(do_ptrs)
- # else:
- # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
- # else:
- # if EVEN_HEADDIM:
- # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
- # else:
- # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
- # & (offs_d[None, :] < headdim), other=0.0)
- dv += tl.dot(p.to(do.dtype), do, trans_a=True)
- # compute dp = dot(v, do)
- # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
- # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
- # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
- if not (EVEN_M & EVEN_HEADDIM):
- tl.debug_barrier()
- dp = tl.dot(do, v, trans_b=True)
- # There's a race condition for headdim=48
- if not EVEN_HEADDIM:
- tl.debug_barrier()
- # compute ds = p * (dp - delta[:, None])
- # Putting the subtraction after the dp matmul (instead of before) is slightly faster
- Di = tl.load(D + offs_m_curr)
- # Converting ds to q.dtype here reduces register pressure and makes it much faster
- # for BLOCK_HEADDIM=128
- ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
- # compute dk = dot(ds.T, q)
- dk += tl.dot(ds, q, trans_a=True)
- # compute dq
- if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix'
- tl.debug_barrier()
- if not ATOMIC_ADD:
- if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
- dq = tl.load(dq_ptrs, eviction_policy="evict_last")
- dq += tl.dot(ds, k)
- tl.store(dq_ptrs, dq, eviction_policy="evict_last")
- else:
- if EVEN_HEADDIM:
- dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
- eviction_policy="evict_last")
- dq += tl.dot(ds, k)
- tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
- eviction_policy="evict_last")
- else:
- dq = tl.load(dq_ptrs,
- mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
- other=0.0, eviction_policy="evict_last")
- dq += tl.dot(ds, k)
- tl.store(dq_ptrs, dq,
- mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
- eviction_policy="evict_last")
- else: # If we're parallelizing across the seqlen_k dimension
- dq = tl.dot(ds, k)
- if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
- tl.atomic_add(dq_ptrs, dq)
- else:
- if EVEN_HEADDIM:
- tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
- else:
- tl.atomic_add(dq_ptrs, dq,
- mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
- # increment pointers
- dq_ptrs += BLOCK_M * stride_dqm
- q_ptrs += BLOCK_M * stride_qm
- do_ptrs += BLOCK_M * stride_dom
- if BIAS_TYPE == 'matrix':
- b_ptrs += BLOCK_M * stride_bm
- # write-back
- dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
- dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
- _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
- EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
-
-
-def init_to_zero(name):
- return lambda nargs: nargs[name].zero_()
-
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
- # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
- # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
- # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
- # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
- # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
- # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
- ],
- key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
-)
-@triton.heuristics(
- {
- "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
- "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
- "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
- }
-)
-@triton.jit
-def _bwd_kernel(
- Q, K, V, Bias,
- DO, DQ, DK, DV,
- LSE, D,
- softmax_scale,
- stride_qb, stride_qh, stride_qm,
- stride_kb, stride_kh, stride_kn,
- stride_vb, stride_vh, stride_vn,
- stride_bb, stride_bh, stride_bm,
- stride_dob, stride_doh, stride_dom,
- stride_dqb, stride_dqh, stride_dqm,
- stride_dkb, stride_dkh, stride_dkn,
- stride_dvb, stride_dvh, stride_dvn,
- nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
- CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
- BIAS_TYPE: tl.constexpr,
- IS_CAUSAL: tl.constexpr,
- BLOCK_HEADDIM: tl.constexpr,
- SEQUENCE_PARALLEL: tl.constexpr,
- EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-):
- off_hb = tl.program_id(1)
- off_b = off_hb // nheads
- off_h = off_hb % nheads
- # offset pointers for batch/head
- Q += off_b * stride_qb + off_h * stride_qh
- K += off_b * stride_kb + off_h * stride_kh
- V += off_b * stride_vb + off_h * stride_vh
- DO += off_b * stride_dob + off_h * stride_doh
- DQ += off_b * stride_dqb + off_h * stride_dqh
- DK += off_b * stride_dkb + off_h * stride_dkh
- DV += off_b * stride_dvb + off_h * stride_dvh
- if BIAS_TYPE != 'none':
- Bias += off_b * stride_bb + off_h * stride_bh
- # pointer to row-wise quantities in value-like data
- D += off_hb * seqlen_q_rounded
- LSE += off_hb * seqlen_q_rounded
- if not SEQUENCE_PARALLEL:
- num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
- for start_n in range(0, num_block_n):
- _bwd_kernel_one_col_block(
- start_n,
- Q, K, V, Bias,
- DO, DQ, DK, DV,
- LSE, D,
- softmax_scale,
- stride_qm, stride_kn, stride_vn, stride_bm,
- stride_dom, stride_dqm, stride_dkn, stride_dvn,
- seqlen_q, seqlen_k, headdim,
- ATOMIC_ADD=False,
- BIAS_TYPE=BIAS_TYPE,
- IS_CAUSAL=IS_CAUSAL,
- BLOCK_HEADDIM=BLOCK_HEADDIM,
- EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
- )
- else:
- start_n = tl.program_id(0)
- _bwd_kernel_one_col_block(
- start_n,
- Q, K, V, Bias,
- DO, DQ, DK, DV,
- LSE, D,
- softmax_scale,
- stride_qm, stride_kn, stride_vn, stride_bm,
- stride_dom, stride_dqm, stride_dkn, stride_dvn,
- seqlen_q, seqlen_k, headdim,
- ATOMIC_ADD=True,
- BIAS_TYPE=BIAS_TYPE,
- IS_CAUSAL=IS_CAUSAL,
- BLOCK_HEADDIM=BLOCK_HEADDIM,
- EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
- )
-
-
-def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
- # shape constraints
- batch, seqlen_q, nheads, d = q.shape
- _, seqlen_k, _, _ = k.shape
- assert k.shape == (batch, seqlen_k, nheads, d)
- assert v.shape == (batch, seqlen_k, nheads, d)
- assert d <= 128, 'FlashAttention only support head dimensions up to 128'
- assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
- assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
- assert q.is_cuda and k.is_cuda and v.is_cuda
- softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
-
- has_bias = bias is not None
- bias_type = 'none'
- if has_bias:
- assert bias.dtype in [q.dtype, torch.float]
- assert bias.is_cuda
- assert bias.dim() == 4
- if bias.stride(-1) != 1:
- bias = bias.contiguous()
- if bias.shape[2:] == (1, seqlen_k):
- bias_type = 'vector'
- elif bias.shape[2:] == (seqlen_q, seqlen_k):
- bias_type = 'matrix'
- else:
- raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
- ' or (seqlen_q, seqlen_k)')
- bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
- seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
- lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
- tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
- o = torch.empty_like(q)
-
- BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
- BLOCK = 128
- num_warps = 4 if d <= 64 else 8
- grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
- _fwd_kernel[grid](
- q, k, v, bias, o,
- lse, tmp,
- softmax_scale,
- q.stride(0), q.stride(2), q.stride(1),
- k.stride(0), k.stride(2), k.stride(1),
- v.stride(0), v.stride(2), v.stride(1),
- *bias_strides,
- o.stride(0), o.stride(2), o.stride(1),
- nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
- seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
- # Can't use kwargs here because triton autotune expects key to be args, not kwargs
- # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
- bias_type, causal, BLOCK_HEADDIM,
- BLOCK_M=BLOCK, BLOCK_N=BLOCK,
- num_warps=num_warps,
- num_stages=1,
- )
- return o, lse, softmax_scale # softmax_scale could have been updated
-
-
-def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
- # Make sure that the last dimension is contiguous
- if do.stride(-1) != 1:
- do = do.contiguous()
- batch, seqlen_q, nheads, d = q.shape
- _, seqlen_k, _, _ = k.shape
- # assert d in {16, 32, 64, 128}
- assert d <= 128
- seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
- assert lse.shape == (batch, nheads, seqlen_q_rounded)
- assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
- assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
- softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
- # dq_accum = torch.zeros_like(q, dtype=torch.float32)
- dq_accum = torch.empty_like(q, dtype=torch.float32)
- delta = torch.empty_like(lse)
- # delta = torch.zeros_like(lse)
-
- BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
- grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
- _bwd_preprocess_do_o_dot[grid](
- o, do, delta,
- o.stride(0), o.stride(2), o.stride(1),
- do.stride(0), do.stride(2), do.stride(1),
- nheads, seqlen_q, seqlen_q_rounded, d,
- BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
- )
-
- has_bias = bias is not None
- bias_type = 'none'
- if has_bias:
- assert bias.dtype in [q.dtype, torch.float]
- assert bias.is_cuda
- assert bias.dim() == 4
- assert bias.stride(-1) == 1
- if bias.shape[2:] == (1, seqlen_k):
- bias_type = 'vector'
- elif bias.shape[2:] == (seqlen_q, seqlen_k):
- bias_type = 'matrix'
- else:
- raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
- ' or (seqlen_q, seqlen_k)')
- bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
- # BLOCK_M = 128
- # BLOCK_N = 64
- # num_warps = 4
- grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
- batch * nheads)
- _bwd_kernel[grid](
- q, k, v, bias,
- do, dq_accum, dk, dv,
- lse, delta,
- softmax_scale,
- q.stride(0), q.stride(2), q.stride(1),
- k.stride(0), k.stride(2), k.stride(1),
- v.stride(0), v.stride(2), v.stride(1),
- *bias_strides,
- do.stride(0), do.stride(2), do.stride(1),
- dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
- dk.stride(0), dk.stride(2), dk.stride(1),
- dv.stride(0), dv.stride(2), dv.stride(1),
- nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
- seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
- # Can't use kwargs here because triton autotune expects key to be args, not kwargs
- # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
- bias_type, causal, BLOCK_HEADDIM,
- # SEQUENCE_PARALLEL=False,
- # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
- # num_warps=num_warps,
- # num_stages=1,
- )
- dq.copy_(dq_accum)
-
-
-class FlashAttnQKVPackedFunc(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
- """
- qkv: (batch, seqlen, 3, nheads, headdim)
- bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
- ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
- """
- # Make sure that the last dimension is contiguous
- if qkv.stride(-1) != 1:
- qkv = qkv.contiguous()
- o, lse, ctx.softmax_scale = _flash_attn_forward(
- qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal,
- softmax_scale=softmax_scale
- )
- ctx.save_for_backward(qkv, o, lse, bias)
- ctx.causal = causal
- return o
-
- @staticmethod
- def backward(ctx, do):
- qkv, o, lse, bias = ctx.saved_tensors
- assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
- # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
- # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
- with torch.inference_mode():
- dqkv = torch.empty_like(qkv)
- _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
- dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
- bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
- return dqkv, None, None, None
-
-
-flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
-
-
-class FlashAttnKVPackedFunc(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
- """
- q: (batch, seqlen_q, nheads, headdim)
- kv: (batch, seqlen_k, 2, nheads, headdim)
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
- """
- # Make sure that the last dimension is contiguous
- q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
- o, lse, ctx.softmax_scale = _flash_attn_forward(
- q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
- )
- ctx.save_for_backward(q, kv, o, lse, bias)
- ctx.causal = causal
- return o
-
- @staticmethod
- def backward(ctx, do):
- q, kv, o, lse, bias = ctx.saved_tensors
- if len(ctx.needs_input_grad) >= 3:
- assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
- # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
- # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
- with torch.inference_mode():
- dq = torch.empty_like(q)
- dkv = torch.empty_like(kv)
- _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
- dq, dkv[:, :, 0], dkv[:, :, 1],
- bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
- return dq, dkv, None, None, None
-
-
-flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
-
-
-class FlashAttnFunc(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
- """
- q: (batch_size, seqlen_q, nheads, headdim)
- k, v: (batch_size, seqlen_k, nheads, headdim)
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
- """
- # Make sure that the last dimension is contiguous
- q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
- o, lse, ctx.softmax_scale = _flash_attn_forward(
- q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
- )
- ctx.save_for_backward(q, k, v, o, lse, bias)
- ctx.causal = causal
- return o
-
- @staticmethod
- def backward(ctx, do):
- q, k, v, o, lse, bias = ctx.saved_tensors
- assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
- # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
- # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
- with torch.inference_mode():
- dq = torch.empty_like(q)
- dk = torch.empty_like(k)
- dv = torch.empty_like(v)
- _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
- bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
- return dq, dk, dv, None, None, None
-
-
-flash_attn_func = FlashAttnFunc.apply
diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py
deleted file mode 100644
index eb4a65cd62..0000000000
--- a/llmfoundry/models/layers/llama_attention_monkeypatch.py
+++ /dev/null
@@ -1,317 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-# This file is copied and modified from
-# https://github.com/huggingface/transformers/blob/fe3c8ab1af558b95f67f5fafc0c55f09fd2b09db/src/transformers/models/llama/modeling_llama.py
-# See the clearly denoted code blocks for the main modifications (there are a few others like type ignores, and error messages)
-
-import logging
-from typing import Callable, Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention
-
-from llmfoundry.models.layers.attention import (
- scaled_multihead_dot_product_attention, triton_flash_attn_fn)
-
-log = logging.getLogger(__name__)
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """Equivalent of torch.repeat_interleave(x, dim=1,
-
- repeats=n_rep).
-
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch, num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(
- q: torch.Tensor,
- k: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- position_ids: Optional[torch.Tensor] = None,
- unsqueeze_dim: int = 1,
-):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable:
- if patch_fn_name == 'torch':
- return llama_attention_patch_torch
- elif patch_fn_name == 'triton':
- return llama_attention_patch_triton
- else:
- raise ValueError(
- f'Unrecognized llama attention patch function: {patch_fn_name}')
-
-
-def llama_attention_patch_torch(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if use_cache:
- raise NotImplementedError(
- 'use_cache is not yet supported when patching Llama attention.')
-
- bsz, q_len, _ = hidden_states.size()
-
- if self.config.pretraining_tp > 1:
- key_value_slicing = (self.num_key_value_heads *
- self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp,
- dim=0)
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [
- F.linear(hidden_states, query_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [
- F.linear(hidden_states, value_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, position_ids)
-
- query_states, key_states = apply_rotary_pos_emb(
- q=query_states,
- k=key_states,
- cos=cos,
- sin=sin,
- position_ids=None,
- )
-
- ### MAIN MODIFICATIONS START HERE ###
- query_states = query_states.transpose(1, 2).view(
- bsz, q_len, self.num_heads * self.head_dim)
- key_states = key_states.transpose(1, 2).view(
- bsz, q_len, self.num_key_value_heads * self.head_dim)
- value_states = value_states.transpose(1, 2).view(
- bsz, q_len, self.num_key_value_heads * self.head_dim)
-
- attn_output, attn_weights, _ = scaled_multihead_dot_product_attention(
- query=query_states,
- key=key_states,
- value=value_states,
- n_heads=self.num_heads,
- kv_n_heads=self.num_key_value_heads,
- past_key_value=None,
- softmax_scale=None,
- attn_bias=attention_mask,
- key_padding_mask=None,
- is_causal=False, # The causal mask is propagated from LLamaForCausalLM
- dropout_p=0,
- training=self.training,
- needs_weights=False,
- )
- ### MAIN MODIFICATIONS END HERE ###
-
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(self.hidden_size //
- self.config.pretraining_tp,
- dim=2)
- o_proj_slices = self.o_proj.weight.split(self.hidden_size //
- self.config.pretraining_tp,
- dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.o_proj(attn_output)
-
- assert isinstance(attn_output, torch.Tensor)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, None
-
-
-def llama_attention_patch_triton(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if use_cache:
- raise NotImplementedError(
- 'use_cache is not yet supported when patching Llama attention.')
- # output_attentions is not support for triton attention
- if output_attentions:
- raise NotImplementedError(
- 'output_attentions is not supported when patching Llama attention with triton attention.'
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- if self.config.pretraining_tp > 1:
- key_value_slicing = (self.num_key_value_heads *
- self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp,
- dim=0)
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [
- F.linear(hidden_states, query_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [
- F.linear(hidden_states, value_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(
- q=query_states,
- k=key_states,
- cos=cos,
- sin=sin,
- position_ids=None,
- )
-
- ### MAIN MODIFICATIONS START HERE ###
- query_states = query_states.transpose(1, 2).view(
- bsz, q_len, self.num_heads * self.head_dim)
- key_states = key_states.transpose(1, 2).view(
- bsz, q_len, self.num_key_value_heads * self.head_dim)
- value_states = value_states.transpose(1, 2).view(
- bsz, q_len, self.num_key_value_heads * self.head_dim)
-
- attn_output, _, _ = triton_flash_attn_fn(
- query=query_states,
- key=key_states,
- value=value_states,
- n_heads=self.num_heads,
- kv_n_heads=self.num_key_value_heads,
- past_key_value=None,
- softmax_scale=None,
- attn_bias=attention_mask,
- key_padding_mask=None,
- is_causal=False, # The causal mask is propagated from LLamaForCausalLM
- dropout_p=0,
- training=self.training,
- needs_weights=False,
- )
- ### MAIN MODIFICATIONS END HERE ###
-
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(self.hidden_size //
- self.config.pretraining_tp,
- dim=2)
- o_proj_slices = self.o_proj.weight.split(self.hidden_size //
- self.config.pretraining_tp,
- dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.o_proj(attn_output)
-
- assert isinstance(attn_output, torch.Tensor)
-
- return attn_output, None, None
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index b5a099002e..20c3850a82 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -20,8 +20,6 @@
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
-
ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
}
@@ -81,16 +79,13 @@ def __init__(
attn_config (Dict): A dictionary used to configure the model's attention module:
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
attn_pdrop (float): The dropout probability for the attention layers.
- attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
+ attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
use the default scale of ``1/sqrt(d_keys)``.
- prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
- extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
- can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
which sub-sequence each token belongs to.
@@ -210,43 +205,20 @@ def _validate_config(self) -> None:
raise ValueError(
"self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
)
- if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
+ if self.attn_config['attn_impl'] not in ['torch', 'flash']:
raise ValueError(
f"Unknown attn_impl={self.attn_config['attn_impl']}")
- if self.attn_config['prefix_lm']:
- warnings.warn(
- VersionedDeprecationWarning(
- 'Support for Prefix Language Models is deprecated.',
- remove_version='0.7.0'))
- if self.attn_config['attn_impl'] == 'triton':
- warnings.warn(
- VersionedDeprecationWarning(
- 'Support for triton attention is deprecated. Please use torch or flash attention.',
- remove_version='0.7.0'))
-
- if self.attn_config['prefix_lm'] and self.attn_config[
- 'attn_impl'] not in ['torch', 'triton']:
- raise NotImplementedError(
- 'prefix_lm only implemented with torch and triton attention.')
-
- if self.attn_config[
- 'attn_impl'] == 'triton' and not self.attn_config['prefix_lm']:
- warnings.warn(
- UserWarning(
- 'If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton".'
- ))
-
if self.attn_config['alibi'] and not check_alibi_support(
self.attn_config['attn_impl']):
raise NotImplementedError(
- 'alibi only implemented with torch, triton, and flash (v2.4.2 or higher) attention.'
+ 'alibi only implemented with torch and flash (v2.4.2 or higher) attention.'
)
if self.attn_config['attn_uses_sequence_id'] and not (
- self.attn_config['attn_impl'] in ['torch', 'triton'] or
+ self.attn_config['attn_impl'] == 'torch' or
(self.attn_config['attn_impl'] == 'flash' and
is_flash_v2_installed(v2_version='v2.1.2'))):
raise NotImplementedError(
- 'attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention.'
+ 'attn_uses_sequence_id only implemented with torch and flash (v2.1.2 or higher) attention.'
)
if self.attn_config['rope'] and (self.attn_config['rope_impl']
not in ['dail', 'hf']):
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index bf93bd71f0..e0a666f62c 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -53,14 +53,6 @@
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
-from llmfoundry.models.utils.adapt_tokenizer import (
- AutoTokenizerForMOD, # type: ignore (see note)
- adapt_tokenizer_for_denoising, # type: ignore (see note)
-)
-from llmfoundry.models.utils.hf_prefixlm_converter import (
- add_bidirectional_mask_if_missing, # type: ignore (see note)
- convert_hf_causal_lm_to_prefix_lm, # type: ignore (see note)
-)
from llmfoundry.models.utils.meta_init_context import \
init_empty_weights # type: ignore (see note)
from llmfoundry.models.utils.param_init_fns import (
@@ -72,12 +64,6 @@
build_act_ckpt_mod_to_blocks,
check_mapping_blocks_overlap)
-try:
- from llmfoundry.models.layers.flash_attn_triton import flash_attn_func as flash_attn_func
-except:
- pass
-# isort: on
-
import logging
log = logging.getLogger(__name__)
@@ -299,7 +285,6 @@ def __init__(self, config: MPTConfig):
super().__init__(config)
self.attn_impl = config.attn_config['attn_impl']
- self.prefix_lm = config.attn_config['prefix_lm']
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
self.alibi = config.attn_config['alibi']
self.alibi_bias_max = config.attn_config['alibi_bias_max']
@@ -364,7 +349,7 @@ def __init__(self, config: MPTConfig):
)
self.apply(self.param_init_fn)
- self.is_causal = not self.prefix_lm
+ self.is_causal = True
# define attn mask
self._attn_bias_initialized = False
@@ -374,7 +359,6 @@ def __init__(self, config: MPTConfig):
config.n_heads,
config.max_seq_len,
self.alibi,
- prefix_lm=self.prefix_lm,
causal=self.is_causal,
use_sequence_id=self.attn_uses_sequence_id,
)
@@ -407,7 +391,6 @@ def _attn_bias(
device: torch.device,
dtype: torch.dtype,
attention_mask: Optional[torch.ByteTensor] = None,
- prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
if not self._attn_bias_initialized:
@@ -426,8 +409,7 @@ def _attn_bias(
)
self._attn_bias_initialized = True
- # flash does not support prefix_lm and will incorporate any
- # attention_mask inside the attention module
+ # flash will incorporate any attention_mask inside the attention module
if self.attn_impl == 'flash':
return self.attn_bias, attention_mask
@@ -438,19 +420,13 @@ def _attn_bias(
attn_bias = self.attn_bias
- # If using torch or triton, we incorporate the prefix_mask (if appropriate)
- if self.prefix_lm:
- assert isinstance(attn_bias, torch.Tensor) # pyright
- assert isinstance(prefix_mask, torch.Tensor) # pyright
- attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
-
- # If using torch or triton, we incorporate sequence_id (if appropriate)
+ # If using torch, we incorporate sequence_id (if appropriate)
if self.attn_uses_sequence_id and sequence_id is not None:
assert isinstance(attn_bias, torch.Tensor) # pyright
attn_bias = apply_sequence_id(attn_bias, sequence_id,
self.config.max_seq_len)
- # If using torch or triton, we incorporate attention_mask. This will output
+ # If using torch, we incorporate attention_mask. This will output
# None in place of attention_mask since it will not be further needed in the
# attention modules.
if attention_mask is not None:
@@ -463,54 +439,17 @@ def _attn_bias(
# clamp to 0 necessary for torch 2.0 compile()
_s_k = max(0, attn_bias.size(-1) - s_k)
attn_bias = attn_bias[:, :, :, _s_k:]
- if prefix_mask is not None and (attention_mask.shape !=
- prefix_mask.shape):
- raise ValueError(
- f'attention_mask shape={attention_mask.shape} ' +
- f'and prefix_mask shape={prefix_mask.shape} are not equal.')
min_val = torch.finfo(attn_bias.dtype).min
attn_bias = attn_bias.masked_fill(
~attention_mask.view(-1, 1, 1, s_k), min_val)
return attn_bias, attention_mask
- def _apply_prefix_mask(self, attn_bias: torch.Tensor,
- prefix_mask: torch.Tensor) -> torch.Tensor:
- s_k, s_q = attn_bias.shape[-2:]
- if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
- raise ValueError(
- 'attn_bias does not match the expected shape. ' +
- f'The last two dimensions should both be {self.config.max_length} '
- + f'but are {s_k} and {s_q}.')
- seq_len = prefix_mask.shape[-1]
- if seq_len > self.config.max_seq_len:
- raise ValueError(
- f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
- )
-
- # select seq_len subset of attn mask
- attn_bias = attn_bias[..., :seq_len, :seq_len]
-
- # Mix the causal max and the bidirectional mask to get the full
- # allowable attention (i.e. full = not accounting for padding yet)
- causal = torch.tril(
- torch.ones((seq_len, seq_len),
- dtype=torch.bool,
- device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
- prefix = prefix_mask.view(-1, 1, 1, seq_len)
- cannot_attend = ~torch.logical_or(causal, prefix.bool())
-
- min_val = torch.finfo(attn_bias.dtype).min
- attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
-
- return attn_bias
-
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
- prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -526,9 +465,6 @@ def forward(
if attention_mask is not None:
attention_mask = attention_mask.bool() # type: ignore
- if prefix_mask is not None:
- prefix_mask = prefix_mask.bool() # type: ignore
-
# These args are passed in by keyword in huggingface's generate function
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
# but have not yet been fully implemented in MPTModel
@@ -538,7 +474,7 @@ def forward(
if output_attentions:
if self.attn_impl != 'torch':
raise NotImplementedError(
- 'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
+ 'output_attentions is not implemented for MPT when using attn_impl `flash`.'
)
if (self.training and attention_mask is not None and
@@ -546,11 +482,6 @@ def forward(
raise NotImplementedError(
'MPT does not support training with left padding.')
- if self.prefix_lm and prefix_mask is None:
- raise ValueError(
- 'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
- )
-
if self.training:
if self.attn_uses_sequence_id and sequence_id is None:
raise ValueError(
@@ -594,8 +525,8 @@ def forward(
+
f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
)
- # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
- # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
+ # For attn_impl: flash, the past key tensor spec is (batch, seq, dim).
+ # For attn_impl: torch, the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
past_position = past_key_values[0][0].size(1)
if self.attn_impl == 'torch':
@@ -654,7 +585,6 @@ def forward(
device=x.device,
dtype=torch.float32,
attention_mask=attention_mask,
- prefix_mask=prefix_mask,
sequence_id=sequence_id,
)
attention_mask_in_length = gen_attention_mask_in_length(
@@ -825,7 +755,6 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
- prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
@@ -843,7 +772,6 @@ def forward(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
- prefix_mask=prefix_mask,
sequence_id=sequence_id,
return_dict=return_dict,
output_attentions=output_attentions,
@@ -975,16 +903,6 @@ def prepare_inputs_for_generation(
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
- if self.transformer.prefix_lm:
- # Leverage a convenience of sequential generation!
- prefix_mask = torch.ones_like(attention_mask)
- # This requires that we're using the cache
- if kwargs.get('use_cache') == False:
- raise NotImplementedError(
- 'MPT with prefix_lm=True does not support use_cache=False.')
- else:
- prefix_mask = None
-
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
@@ -993,7 +911,6 @@ def prepare_inputs_for_generation(
model_inputs.update({
'attention_mask': attention_mask,
- 'prefix_mask': prefix_mask,
'sequence_id': sequence_id,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache', True),
@@ -1089,13 +1006,9 @@ def get_targets(self, batch: Mapping) -> torch.Tensor:
return targets
def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
- if self.model.transformer.prefix_lm:
- add_bidirectional_mask_if_missing(batch)
- # Note: prefix_mask is only used if model.prefix_lm is True
return self.model(
input_ids=batch.get('input_ids', None),
attention_mask=batch.get('attention_mask', None),
- prefix_mask=batch.get('bidirectional_mask', None),
sequence_id=batch.get('sequence_id', None),
inputs_embeds=batch.get('inputs_embeds', None),
)
diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py
index 35a15e530a..7c808ff449 100644
--- a/llmfoundry/models/utils/__init__.py
+++ b/llmfoundry/models/utils/__init__.py
@@ -1,22 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-from llmfoundry.models.utils.adapt_tokenizer import (
- AutoTokenizerForMOD, adapt_tokenizer_for_denoising)
-from llmfoundry.models.utils.hf_prefixlm_converter import (
- add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm)
from llmfoundry.models.utils.meta_init_context import (init_empty_weights,
init_on_device)
from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY,
generic_param_init_fn_)
__all__ = [
- 'AutoTokenizerForMOD',
- 'adapt_tokenizer_for_denoising',
- 'convert_hf_causal_lm_to_prefix_lm',
'init_empty_weights',
'init_on_device',
- 'add_bidirectional_mask_if_missing',
'generic_param_init_fn_',
'MODEL_INIT_REGISTRY',
]
diff --git a/llmfoundry/models/utils/adapt_tokenizer.py b/llmfoundry/models/utils/adapt_tokenizer.py
deleted file mode 100644
index 8cb0c33697..0000000000
--- a/llmfoundry/models/utils/adapt_tokenizer.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Any
-
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
-
-# For consistency with T5 Tokenizer, which is what this adaptation aims to mimic,
-# we hardcode there to be 100 sentinel tokens
-NUM_SENTINEL_TOKENS: int = 100
-
-
-def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase) -> None:
- """Adds sentinel tokens and padding token (if missing).
-
- Expands the tokenizer vocabulary to include sentinel tokens
- used in mixture-of-denoiser tasks as well as a padding token.
-
- All added tokens are added as special tokens. No tokens are
- added if sentinel tokens and padding token already exist.
- """
- # Add sentinel tokens (e.g., , , and so on). Has no effect if these are already in the vocab.
- sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)]
- tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
-
- # If the padding token has not been set, add and use it
- if tokenizer.pad_token is None:
- tokenizer.add_tokens('', special_tokens=True)
- tokenizer.pad_token = ''
- assert tokenizer.pad_token_id is not None
-
- # Register a property that gets us the ids of the sentinel tokens
- sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)])
- _sentinel_token_ids = tokenizer(sentinels,
- add_special_tokens=False).input_ids
-
- tokenizer.sentinel_token_ids = _sentinel_token_ids
-
-
-class AutoTokenizerForMOD(AutoTokenizer):
- """AutoTokenizer + Adaptation for MOD.
-
- A simple wrapper around AutoTokenizer to make instantiating
- an MOD-adapted tokenizer a bit easier.
-
- MOD-adapted tokenizers have sentinel tokens (e.g., ),
- a padding token, and a property to get the token ids of the
- sentinel tokens.
- """
-
- @classmethod
- def from_pretrained(cls, *args: Any,
- **kwargs: Any) -> PreTrainedTokenizerBase:
- """See `AutoTokenizer.from_pretrained` docstring."""
- tokenizer = super().from_pretrained(*args, **kwargs)
- adapt_tokenizer_for_denoising(tokenizer)
- return tokenizer
diff --git a/llmfoundry/models/utils/hf_prefixlm_converter.py b/llmfoundry/models/utils/hf_prefixlm_converter.py
deleted file mode 100644
index 692fab94c2..0000000000
--- a/llmfoundry/models/utils/hf_prefixlm_converter.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Converts Huggingface Causal LM to Prefix LM.
-
-Conversion does lightweight surgery on a HuggingFace
-Causal LM to convert it to a Prefix LM.
-
-Prefix LMs accepts a `bidirectional_mask` input in `forward`
-and treat the input prompt as the prefix in `generate`.
-"""
-
-from types import MethodType
-from typing import Any, List, MutableMapping, Optional, Tuple, Union
-
-import torch
-from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
-from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
-from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
-
-_SUPPORTED_GPT_MODELS = (
- GPT2LMHeadModel,
- GPTJForCausalLM,
- GPTNeoForCausalLM,
- GPTNeoXForCausalLM,
-)
-
-CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM,
- GPTNeoXForCausalLM,]
-
-
-def _convert_gpt_causal_lm_to_prefix_lm(
- model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
- """Converts a GPT-style Causal LM to a Prefix LM.
-
- Supported HuggingFace model classes:
- - `GPT2LMHeadModel`
- - `GPTNeoForCausalLM`
- - `GPTNeoXForCausalLM`
- - `GPTJForCausalLM`
-
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
- """
- if hasattr(model, '_prefix_lm_converted'):
- return model
-
- assert isinstance(model, _SUPPORTED_GPT_MODELS)
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
-
- def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
- """Helper that gets a list of the model's attention modules.
-
- Each module has a `bias` buffer used for causal masking. The Prefix LM
- conversion adds logic to dynamically manipulate these biases to support
- Prefix LM attention masking.
- """
- attn_modules = []
-
- if isinstance(model, GPTNeoXForCausalLM):
- blocks = model.gpt_neox.layers
- else:
- blocks = model.transformer.h
-
- for block in blocks:
- if isinstance(model, GPTNeoForCausalLM):
- # Ignore "local" layers in this model type
- if block.attn.attention_type != 'global':
- continue
- attn_module = block.attn.attention
- elif isinstance(model, GPTNeoXForCausalLM):
- attn_module = block.attention
- else:
- attn_module = block.attn
-
- attn_modules.append(attn_module)
-
- return attn_modules
-
- # Rename methods to allow:
- # - new `forward` to wrap original `forward`
- # - new `generate` to wrap original `generate`
- setattr(model, '_original_forward', getattr(model, 'forward'))
- setattr(model, '_original_generate', getattr(model, 'generate'))
-
- def forward(
- self: CAUSAL_GPT_TYPES,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- bidirectional_mask: Optional[torch.Tensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ):
- """Wraps original forward to enable PrefixLM attention."""
-
- def call_og_forward():
- if isinstance(self, GPTNeoXForCausalLM):
- return self._original_forward(
- input_ids=input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- labels=labels,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- else:
- return self._original_forward(
- input_ids=input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- labels=labels,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- if bidirectional_mask is None:
- # This wrapper is a no-op if bidirectional masks are not supplied
- return call_og_forward()
- assert isinstance(bidirectional_mask, torch.Tensor)
-
- attn_modules = _get_attn_modules(model)
-
- # Handle bidirectional_mask sizing
- # Note: all attn_modules.bias have the same size
- b, s = bidirectional_mask.shape
-
- max_length = attn_modules[0].bias.shape[-1] # type: ignore
-
- if s > max_length:
- raise ValueError(
- f'bidirectional_mask sequence length (={s}) exceeds the ' +\
- f'max length allowed by the model ({max_length}).'
- )
- assert s <= max_length
- if s < max_length:
- pad = torch.zeros((int(b), int(max_length - s)),
- dtype=bidirectional_mask.dtype,
- device=bidirectional_mask.device)
- bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
- bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
-
- # Incorporate the bidirectional mask into the original causal mask
- for attn_module in attn_modules:
- assert isinstance(attn_module.bias, torch.Tensor)
- attn_module.bias.data = torch.logical_or(attn_module.bias.data,
- bidirectional)
-
- # Collect outputs using the model's original forward method
- output = call_og_forward()
-
- # Reset the masks
- for attn_module in attn_modules:
- attn_module.bias.data = torch.tril(
- attn_module.bias.data[0, 0])[None, None] # type: ignore
-
- # Return the outputs
- return output
-
- def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any):
- """Wraps original generate to enable PrefixLM attention."""
- attn_modules = _get_attn_modules(model)
-
- # A convenient answer to PrefixLM generation is to set the causal mask
- # to be bidirectional. All the tokens in the input prompt can attend to
- # one another and, since tokens are generated one-by-one, each new
- # token gets to see everything behind it. This depends on activations
- # being cached and not updated, which is how the HF implementation works.
- for attn_module in attn_modules:
- attn_module.bias.data[:] = 1 # type: ignore
-
- # Collect outputs using the model's original forward method
- output = self._original_generate(*args, **kwargs)
-
- # Reset the masks
- for attn_module in attn_modules:
- attn_module.bias.data = torch.tril(
- attn_module.bias.data[0, 0])[None, None] # type: ignore
-
- # Return the outputs
- return output
-
- # Replace `forward` and `generate` with the new wrappers
- setattr(model, 'forward', MethodType(forward, model))
- setattr(model, 'generate', MethodType(generate, model))
-
- # Finally, tag the model so that this conversion cannot happen again.
- setattr(model, '_prefix_lm_converted', True)
- return model
-
-
-_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS
-
-CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM,
- GPTNeoXForCausalLM]
-
-
-def convert_hf_causal_lm_to_prefix_lm(
- model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
- """Converts a HuggingFace Causal LM to a Prefix LM.
-
- Supported HuggingFace model classes:
- - `GPT2LMHeadModel`
- - `GPTNeoForCausalLM`
- - `GPTNeoXForCausalLM`
- - `GPTJForCausalLM`
-
- Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
- `generate` method and/or select underlying methods depending on the model class.
-
- These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
-
- Notes on training:
- To actually train the converted model as a Prefix LM, training batches will need to indicate
- the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
-
- **This is not a standard input and requires custom layers either within or after your dataloader.**
-
- In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
- such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
- That is, the prefix portion of the sequence should not generate any loss. Loss should only be
- generated by the target portion of the sequence.
-
- Notes on `GPTNeoForCausalLM`:
- To simplify the implementation, "global" and "local" attention layers are handled differently.
- For "global" layers, we handle conversion as described above. For "local" layers, which use a
- causal attention mask within a restricted local window, we do not alter the masking.
-
- Notes on `forward` method conversion:
- After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
- which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
- belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
- 0 indicates token positions belonging to the target.
-
- The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
- causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
- the causal masks before returning the result.
-
- Notes on `generate` method conversion:
- After conversion, the `generate` method will have the same signature but will internally
- convert all causal masks to be purely bidirectional, call the original `generate` method, and
- (where appropriate) reset the causal masks before returning the result.
-
- This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
- "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
- each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
- another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
- previously-generated tokens (also as expected in a Prefix LM).
-
- To preserve the API, the original methods are renamed to `_original_forward` and
- `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
- them, respectively. Although implementation details vary by model class.
- """
- if isinstance(model, _SUPPORTED_GPT_MODELS):
- return _convert_gpt_causal_lm_to_prefix_lm(model)
- else:
- raise TypeError(
- f'Cannot convert model to Prefix LM. ' +\
- f'Model does not belong to set of supported HF models:' +\
- f'\n{_SUPPORTED_HF_MODELS}'
- )
-
-
-def add_bidirectional_mask_if_missing(batch: MutableMapping):
- """Attempts to add bidirectional_mask to batch if missing.
-
- Raises:
- KeyError if bidirectional_mask is missing and can't be inferred
- """
- if 'bidirectional_mask' not in batch:
- if batch.get('mode', None) == 'icl_task':
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
- for i, continuation_indices in enumerate(
- batch['continuation_indices']):
- batch['bidirectional_mask'][i, continuation_indices] = 0
- elif ('labels' in batch) and ('attention_mask' in batch):
- batch['bidirectional_mask'] = torch.logical_and(
- torch.eq(batch['attention_mask'], 1),
- torch.eq(batch['labels'], -100),
- ).type_as(batch['attention_mask'])
- else:
- raise KeyError(
- 'No bidirectional_mask in batch and not sure how to construct one.'
- )
diff --git a/pyproject.toml b/pyproject.toml
index 503899c4fa..6bde062abb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,8 +22,7 @@ include = [
# Pyright
[tool.pyright]
-exclude = ['env-**', 'venv*', '**/flash_attn_triton.py', '.venv']
-ignore = ['llmfoundry/models/layers/flash_attn_triton.py']
+exclude = ['env-**', 'venv*', '.venv']
stubPath = "" # suppress useless 'stubPath is not a valid directory' errors
reportUnnecessaryIsInstance = "none" # it is ok to do this for clarity or safety
diff --git a/scripts/misc/convert_examples_ckpt.py b/scripts/misc/convert_examples_ckpt.py
index a533aec72d..db1301674c 100644
--- a/scripts/misc/convert_examples_ckpt.py
+++ b/scripts/misc/convert_examples_ckpt.py
@@ -125,7 +125,7 @@ def convert_examples_ckpt(
'attn_clip_qkv', attn_config_defaults['clip_qkv'])
for k in [
- 'attn_pdrop', 'attn_impl', 'softmax_scale', 'prefix_lm',
+ 'attn_pdrop', 'attn_impl', 'softmax_scale',
'attn_uses_sequence_id', 'alibi', 'alibi_bias_max'
]:
if k in hf_config:
diff --git a/scripts/train/README.md b/scripts/train/README.md
index 0d9a335848..36974ec943 100644
--- a/scripts/train/README.md
+++ b/scripts/train/README.md
@@ -335,7 +335,7 @@ train_loader:
# Using Flash Attention
-Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. The support for Flash Attention V1 has been deprecated, and we recommend using Flash Attention V2. We also recommend using `Flash` attention instead of `Triton` attention, unless you're training Prefix Language Models (in which case we recommend using `Triton`). To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. Next, how you specify to use Flash Attention depends on which model you are using.
+Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). LLM Foundry supports Flash Attention V2. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. Next, how you specify to use Flash Attention depends on which model you are using.
For MPT, you can specify Flash Attention in your YAML like so:
```yaml
@@ -343,8 +343,6 @@ model:
name: mpt_causal_lm
...
attn_config:
- # Will use either V1 or V2 depending on what is installed
- # "triton" will use the Triton implementation
attn_impl: flash
...
```
@@ -356,8 +354,6 @@ model:
pretrained_model_name_or_path: mosaicml/mpt-7b
...
config_overrides:
- # Will use either V1 or V2 depending on what is installed
- # "triton" will use the Triton implementation
attn_config:
attn_impl: flash
...
diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py
index a020745581..aff570e3d4 100644
--- a/scripts/train/benchmarking/submit_benchmarks.py
+++ b/scripts/train/benchmarking/submit_benchmarks.py
@@ -141,7 +141,7 @@ def parse_args():
nargs='+',
help='model sizes to test')
- parser.add_argument('--attn_impl', type=str, default='triton')
+ parser.add_argument('--attn_impl', type=str, default='flash')
parser.add_argument('-c',
'--clusters',
diff --git a/scripts/train/benchmarking/sweep.py b/scripts/train/benchmarking/sweep.py
index 441c9825f8..57b3aa262c 100644
--- a/scripts/train/benchmarking/sweep.py
+++ b/scripts/train/benchmarking/sweep.py
@@ -64,14 +64,14 @@
'--fsdp_config_activation_checkpointing true',
'--fsdp_config_shard_strategy FULL_SHARD',
'--microbatch_size 16',
- '--attn_impl triton',
+ '--attn_impl flash',
],
[
'--model_yamls 30b.yaml',
'--fsdp_config_activation_checkpointing true',
'--fsdp_config_shard_strategy FULL_SHARD',
'--microbatch_size 8',
- '--attn_impl triton',
+ '--attn_impl flash',
],
[
'--model_yamls 70b.yaml',
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 28364bf3c8..6ecbc55e38 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -55,43 +55,10 @@ def validate_config(cfg: DictConfig):
loaders.append(eval_loader)
for loader in loaders:
if loader.name == 'text':
- if cfg.model.name in ['hf_prefix_lm', 'hf_t5']:
+ if cfg.model.name in ['hf_t5']:
raise ValueError(
f'Model type "{cfg.model.name}" is not supported when using the "text " ' +\
- f'dataloader. Please use the "text_denoising" dataloader to pre-train that model type.')
- elif loader.name == 'text_denoising':
- if cfg.model.name == 'hf_causal_lm':
- raise ValueError(
- f'Model type "{cfg.model.name}" is not supported when using the "text_denoising" ' +\
- f'dataloader. Please use the "text" dataloader to pre-train that model type.')
- if loader.mixture_of_denoisers.decoder_only_format and cfg.model.name == 'hf_t5':
- warnings.warn(
- 'Model type "hf_t5" requires `decoder_only_format` to be ``False``. ' +\
- 'Overriding `decoder_only_format` from ``True`` to ``False``.')
- loader.mixture_of_denoisers.decoder_only_format = False
- if (not loader.mixture_of_denoisers.decoder_only_format
- ) and cfg.model.name == 'hf_prefix_lm':
- warnings.warn(
- 'Model type "hf_prefix_lm" requires `decoder_only_format` to be ``True``. ' +\
- 'Overriding `decoder_only_format` from ``False`` to ``True``.')
- loader.mixture_of_denoisers.decoder_only_format = True
- elif loader.name == 'finetuning':
- if cfg.model.name == 'hf_prefix_lm':
- is_prefix_lm = True
- elif cfg.model.name == 'mpt_causal_lm':
- is_prefix_lm = cfg.model.get('attn_config',
- {}).get('prefix_lm', False)
- else:
- # Note: This only covers the two prefix-lms introduced in this repo
- is_prefix_lm = False
- target_responses = loader.dataset.get('target_responses', 'last')
- target_prompts = loader.dataset.get('target_prompts', 'none')
- prefix_lm_safe = target_responses == 'last' and target_prompts == 'none'
- if is_prefix_lm and not prefix_lm_safe:
- raise ValueError(
- 'The model configuration is building a Prefix-LM, which requires that the finetuning ' +\
- 'dataloader uses `target_responses`="last" and `target_prompts`="none".'
- )
+ f'dataloader. Only finetuning is supported.')
if 'icl_tasks' in cfg:
if cfg.model.name == 'hf_t5':
diff --git a/setup.py b/setup.py
index 9a26307e5a..22b7cb17ca 100644
--- a/setup.py
+++ b/setup.py
@@ -66,9 +66,6 @@
'mosaicml-cli>=0.6.10,<1',
'onnx==1.14.0',
'onnxruntime==1.15.1',
- 'cmake>=3.25.0,<=3.26.3', # required for triton-pre-mlir below
- # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
- 'triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir_sm90#subdirectory=python',
'boto3>=1.21.45,<2',
'huggingface-hub>=0.17.0,<1.0',
'beautifulsoup4>=4.12.2,<5', # required for model download utils
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 0119f8edd2..3949c091aa 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -24,7 +24,6 @@
from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename
from llmfoundry.data.finetuning import build_finetuning_dataloader
-from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
from llmfoundry.utils.builders import (build_composer_model, build_optimizer,
build_tokenizer)
from scripts.inference.convert_composer_to_hf import convert_composer_to_hf
@@ -789,50 +788,6 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool,
delete_transformers_cache()
-@pytest.mark.gpu
-@pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_convert_and_generate_triton(tie_word_embeddings: str,
- tmp_path: pathlib.Path):
- delete_transformers_cache()
-
- cfg = get_config()
- cfg['model']['init_device'] = 'cpu'
- cfg['tie_word_embeddings'] = tie_word_embeddings
- tokenizer = transformers.AutoTokenizer.from_pretrained(
- 'EleutherAI/gpt-neox-20b')
- model = ComposerMPTCausalLM(cfg['model'], tokenizer)
- trainer = Trainer(model=model)
- trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt'))
-
- args = Namespace(composer_path=os.path.join(tmp_path, 'checkpoint.pt'),
- hf_output_path=os.path.join(tmp_path, 'hf-output-folder'),
- output_precision='fp32',
- local_checkpoint_save_location=None,
- hf_repo_for_upload=None,
- trust_remote_code=False,
- test_uploaded_model=False)
- convert_composer_to_hf(args)
-
- config = transformers.AutoConfig.from_pretrained(os.path.join(
- tmp_path, 'hf-output-folder'),
- trust_remote_code=True)
- config.attn_config['attn_impl'] = 'triton'
- model = transformers.AutoModelForCausalLM.from_pretrained(
- os.path.join(tmp_path, 'hf-output-folder'),
- config=config,
- trust_remote_code=True)
- model.to(device='cuda', dtype=torch.bfloat16)
- tokenizer = transformers.AutoTokenizer.from_pretrained(
- os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True)
-
- output = model.generate(tokenizer(
- 'hello', return_tensors='pt')['input_ids'].to(device='cuda'),
- max_new_tokens=1)
- assert output.shape == (1, 2)
-
- delete_transformers_cache()
-
-
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_convert_and_generate_meta(tie_word_embeddings: str,
tmp_path: pathlib.Path):
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index 723b456e7a..f5c2631fa7 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -5,7 +5,6 @@
import pathlib
import random
import shutil
-import tempfile
from argparse import Namespace
from contextlib import nullcontext as does_not_raise
from pathlib import Path
@@ -22,8 +21,7 @@
from omegaconf import OmegaConf as om
from streaming import MDSWriter
-from llmfoundry import (build_finetuning_dataloader,
- build_text_denoising_dataloader)
+from llmfoundry import build_finetuning_dataloader
from llmfoundry.data import build_dataloader
from llmfoundry.data.finetuning.collator import (_HF_IGNORE_INDEX,
validate_target_settings)
@@ -270,72 +268,6 @@ def test_sequence_id_wrapper(eos_token_id: Optional[int],
raise NotImplementedError()
-@pytest.mark.parametrize('decoder_only_format', [True, False])
-@pytest.mark.parametrize('pretokenize', [True, False])
-@pytest.mark.parametrize('packing_ratio', [None, 5.5])
-def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool,
- packing_ratio: Optional[float]):
- # Use the datasets just built in the last test
- tokenizer_name = 'facebook/opt-125m'
- data_local = get_data_local(tokenizer_name, pretokenize)
- path = get_abs_data_path(data_local)
- max_seq_len = 256 if decoder_only_format else 128
-
- if (decoder_only_format is False) and (packing_ratio is not None):
- pytest.xfail('packing_ratio only supported for decoder-only format.')
-
- with tempfile.TemporaryDirectory() as tmpdir:
- cfg = {
- 'name': 'text_denoising',
- 'dataset': {
- 'local': tmpdir,
- 'remote': path,
- 'split': 'val_xsmall',
- 'shuffle': False,
- 'max_seq_len': max_seq_len,
- 'packing_ratio': packing_ratio,
- 'predownload': 1000,
- 'keep_zip': False,
- 'num_workers': None
- },
- 'mixture_of_denoisers': {
- 'decoder_only_format': decoder_only_format,
- 'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
- 'sequence_mask_ratios': 0.25,
- },
- 'drop_last': False,
- 'num_workers': 4,
- }
- cfg = om.create(cfg)
- device_batch_size = 2
-
- expected_keys = ['input_ids', 'attention_mask', 'labels']
- if decoder_only_format:
- expected_keys += ['bidirectional_mask']
- else:
- expected_keys += ['decoder_attention_mask', 'decoder_input_ids']
-
- if packing_ratio is not None:
- expected_keys += ['sequence_id']
-
- tokenizer = build_tokenizer(
- tokenizer_name=tokenizer_name,
- tokenizer_kwargs={'model_max_length': max_seq_len})
-
- loader = build_text_denoising_dataloader(cfg, tokenizer,
- device_batch_size).dataloader
- batch_ix = 0
- for batch in loader:
- for k in expected_keys:
- assert k in batch
- t = batch[k]
- assert t.shape[0] == device_batch_size
- assert t.shape[1] <= max_seq_len
- batch_ix += 1
- if batch_ix >= 5:
- break
-
-
@pytest.mark.parametrize('use_chat_formatting', [True, False])
@pytest.mark.parametrize('decoder_only_format', [True, False])
@pytest.mark.parametrize('allow_pad_trimming', [True, False])
@@ -387,9 +319,7 @@ def test_finetuning_dataloader(use_chat_formatting: bool,
device_batch_size = 2
expected_keys = ['input_ids', 'attention_mask', 'labels']
- if decoder_only_format:
- expected_keys += ['bidirectional_mask']
- else:
+ if not decoder_only_format:
expected_keys += ['decoder_attention_mask', 'decoder_input_ids']
loader = build_finetuning_dataloader(cfg, tokenizer,
@@ -496,9 +426,6 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
tokenizer_kwargs={'model_max_length': max_seq_len},
)
- expected_keys = ['input_ids', 'attention_mask', 'labels']
- expected_keys += ['bidirectional_mask']
-
error_context = contextlib.nullcontext()
if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last:
error_context = pytest.raises(NotEnoughDatasetSamplesError,
@@ -812,9 +739,6 @@ def test_malformed_data(
cfg = om.create(cfg)
- expected_keys = ['input_ids', 'attention_mask', 'labels']
- expected_keys += ['bidirectional_mask']
-
error_context = contextlib.nullcontext()
if add_invalid_prompt_type:
error_context = pytest.raises(InvalidPromptTypeError,
@@ -1044,8 +968,8 @@ def test_token_counting_func(pad_token_id: int, batch_size: int,
@pytest.mark.parametrize('dataloader_type,tensor_input',
[('finetuning-hf', False),
- ('finetuning-streaming', False), ('denoising', False),
- ('text', True), ('text', False)])
+ ('finetuning-streaming', False), ('text', True),
+ ('text', False)])
@pytest.mark.parametrize('pad_token_id', [100, None])
@pytest.mark.parametrize('batch_size', [1, 8])
@pytest.mark.parametrize('model_max_length', [1024])
@@ -1080,10 +1004,6 @@ def test_token_counting_func_dataloader_setting(
torch.tensor(b['input_ids']) for b in batch_tokenized
]
- if dataloader_type == 'denoising':
- expected_token_count += 2 * batch_size # for the two eos tokens
- expected_token_count += 5 * batch_size # for the corruption prefix tokens
-
if dataloader_type in {'finetuning-hf', 'finetuning-streaming'}:
for b in batch_tokenized:
b['labels'] = b['input_ids'].copy() # type: ignore
@@ -1155,30 +1075,6 @@ def test_token_counting_func_dataloader_setting(
monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset',
lambda *args, **kwargs: ds_mock)
dl = build_text_dataloader(cfg, gptt, batch_size)
- elif dataloader_type == 'denoising':
- cfg = DictConfig({
- 'name': 'text_denoising',
- 'dataset': {
- 'local': 'dummy-path',
- 'remote': 'dummy-path',
- 'split': 'val_xsmall',
- 'shuffle': False,
- 'max_seq_len': model_max_length,
- 'packing_ratio': None,
- 'predownload': 1000,
- 'keep_zip': False,
- 'num_workers': None
- },
- 'mixture_of_denoisers': {
- 'decoder_only_format': False,
- 'span_mean_lengths_and_ratios': None,
- 'sequence_mask_ratios': 0.25,
- },
- **common_args
- })
- monkeypatch.setattr('llmfoundry.data.denoising.StreamingTextDataset',
- lambda *args, **kwargs: MagicMock())
- dl = build_text_denoising_dataloader(cfg, gptt, batch_size)
else:
raise NotImplementedError()
diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py
index d541f0a30c..d5de596199 100644
--- a/tests/models/hf/test_hf_config.py
+++ b/tests/models/hf/test_hf_config.py
@@ -6,6 +6,7 @@
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Mapping
+from unittest.mock import Mock, patch
import pytest
import torch
@@ -83,7 +84,7 @@ def test_tie_weights(tie_word_embeddings: bool):
},
{
'attn_config': {
- 'attn_impl': 'triton'
+ 'attn_impl': 'flash',
}
},
{
@@ -94,7 +95,7 @@ def test_tie_weights(tie_word_embeddings: bool):
{
'max_seq_len': 1024,
'attn_config': {
- 'attn_impl': 'triton'
+ 'attn_impl': 'flash',
},
'init_config': {
'emb_init_std': 5
@@ -104,11 +105,13 @@ def test_tie_weights(tie_word_embeddings: bool):
marks=pytest.mark.xfail(reason='"msl" is a ValueError',
strict=True)),
pytest.param({'attn_config': {
- 'attn_iml': 'triton'
+ 'attn_iml': 'flash'
}},
marks=pytest.mark.xfail(reason='"attn_impl" mispelled',
strict=True)),
])
+@patch('llmfoundry.models.layers.attention.is_flash_v2_installed',
+ new=Mock(return_value=True))
def test_hf_config_override(
model_cfg_overrides: Dict[str, Any],
conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml',
diff --git a/tests/models/hf/test_hf_mpt_gen.py b/tests/models/hf/test_hf_mpt_gen.py
index 1df553f126..917e970852 100644
--- a/tests/models/hf/test_hf_mpt_gen.py
+++ b/tests/models/hf/test_hf_mpt_gen.py
@@ -13,14 +13,14 @@
@pytest.mark.gpu
@pytest.mark.parametrize('device', ['cpu', 'gpu'])
-@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
def test_init_hfhub_mpt(
device: str,
attn_impl: str,
build_tiny_hf_mpt: Callable[..., ComposerHFCausalLM],
mpt_tokenizer: PreTrainedTokenizerBase,
):
- if device == 'cpu' and attn_impl == 'triton':
+ if device == 'cpu' and attn_impl == 'flash':
pytest.skip(f'{attn_impl=} not implemented for {device=}.')
composer_device = get_device(device)
diff --git a/tests/models/hf/test_hf_t5.py b/tests/models/hf/test_hf_t5.py
index 12ee0935e3..fb8689e665 100644
--- a/tests/models/hf/test_hf_t5.py
+++ b/tests/models/hf/test_hf_t5.py
@@ -18,8 +18,6 @@ def test_experimental_hf_t5():
},
'pretrained': False,
'init_device': 'cpu',
- 'z_loss': 0.0,
- 'adapt_vocab_for_denoising': False
})
tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base')
diff --git a/tests/models/hf/test_hf_v_mpt.py b/tests/models/hf/test_hf_v_mpt.py
index b44c8d14c2..82b64ce80c 100644
--- a/tests/models/hf/test_hf_v_mpt.py
+++ b/tests/models/hf/test_hf_v_mpt.py
@@ -16,14 +16,9 @@
('flash', 0.0, False, 1, False),
('flash', 0.1, False, 1, False),
('torch', 0.0, False, 1, False),
- ('triton', 0.0, False, 1, False),
- ('triton', 0.1, False, 1, False),
('torch', 0.0, False, 0, False),
- ('triton', 0.0, False, 0, False),
- ('triton', 0.1, False, 0, False),
('flash', 0.0, False, None, True),
('torch', 0.0, False, None, True),
- ('triton', 0.0, False, None, True),
])
def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
mask_val: Optional[int], no_attn_mask: bool):
@@ -93,7 +88,7 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
# extract model cfg
model_cfg = cfg.model
- # use triton attn implementation
+ # use given attn implementation
model_cfg.attn_impl = attn_impl
model_cfg.alibi = alibi
# modify cfg for HF GPT2 compatibility
diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py
index 622e9574ea..a3b17c36df 100644
--- a/tests/models/layers/test_flash_attn.py
+++ b/tests/models/layers/test_flash_attn.py
@@ -6,12 +6,9 @@
import pytest
import torch
-from llmfoundry.models.layers.attention import (attn_bias_shape,
- build_attn_bias,
- check_alibi_support,
- flash_attn_fn, gen_slopes,
- is_flash_v2_installed,
- triton_flash_attn_fn)
+from llmfoundry.models.layers.attention import (
+ attn_bias_shape, build_attn_bias, check_alibi_support, flash_attn_fn,
+ gen_slopes, is_flash_v2_installed, scaled_multihead_dot_product_attention)
from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info
@@ -239,7 +236,7 @@ def test_sliding_window(sliding_window_size: int):
torch.ones(seqlen_1, seqlen_1), diagonal=-(sliding_window_size + 1)).to(
dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min
attn_bias_2 = attn_bias_2 + window_mask_2
- output_2, _, _ = triton_flash_attn_fn(
+ output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
@@ -257,13 +254,15 @@ def test_sliding_window(sliding_window_size: int):
output_2.sum().backward()
- assert torch.allclose(output_1, output_2)
- assert torch.norm(query_2.grad - query_1.grad # type: ignore
- ) <= 1e-2 + 1e-2 * torch.norm(query_2.grad)
- assert torch.norm(key_2.grad - key_1.grad # type: ignore
- ) <= 1e-2 + 1e-2 * torch.norm(key_2.grad)
- assert torch.norm(value_2.grad - value_1.grad # type: ignore
- ) <= 1e-2 + 1e-2 * torch.norm(value_2.grad)
+ print(torch.max(output_1 - output_2))
+
+ _assert_approx_equal(output_1, output_2)
+ assert (query_2.grad is not None) and (query_1.grad is not None)
+ _assert_approx_equal(query_1.grad, query_2.grad)
+ assert (key_2.grad is not None) and (key_1.grad is not None)
+ _assert_approx_equal(key_1.grad, key_2.grad)
+ assert (value_2.grad is not None) and (value_1.grad is not None)
+ _assert_approx_equal(value_1.grad, value_2.grad)
@pytest.mark.gpu
@@ -322,17 +321,16 @@ def test_alibi_bias(n_heads: int):
def gen_bias():
causal = True
- bs = attn_bias_shape('triton',
+ bs = attn_bias_shape('torch',
n_heads,
seqlen_1,
True,
- prefix_lm=False,
use_sequence_id=False,
causal=causal)
attn_bias = torch.zeros(*bs, device=device)
attn_bias = build_attn_bias(
- 'triton',
+ 'torch',
attn_bias,
n_heads,
seqlen_1,
@@ -344,7 +342,7 @@ def gen_bias():
attn_bias_2 = gen_bias()
- output_2, _, _ = triton_flash_attn_fn(
+ output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
@@ -362,13 +360,14 @@ def gen_bias():
output_2.sum().backward()
- assert torch.allclose(output_1, output_2)
+ _assert_approx_equal(output_1, output_2)
assert (query_2.grad is not None) and (query_1.grad is not None)
- assert torch.norm(query_2.grad -
- query_1.grad) <= 1e-2 + 1e-2 * torch.norm(query_2.grad)
+ _assert_approx_equal(query_1.grad, query_2.grad)
assert (key_2.grad is not None) and (key_1.grad is not None)
- assert torch.norm(key_2.grad -
- key_1.grad) <= 1e-2 + 1e-2 * torch.norm(key_2.grad)
+ _assert_approx_equal(key_1.grad, key_2.grad)
assert (value_2.grad is not None) and (value_1.grad is not None)
- assert torch.norm(value_2.grad -
- value_1.grad) <= 1e-2 + 1e-2 * torch.norm(value_2.grad)
+ _assert_approx_equal(value_1.grad, value_2.grad)
+
+
+def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
+ assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)
diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_torch.py
similarity index 98%
rename from tests/models/layers/test_flash_triton_torch.py
rename to tests/models/layers/test_flash_torch.py
index 4e1efa3f34..c0e9f4b3b5 100644
--- a/tests/models/layers/test_flash_triton_torch.py
+++ b/tests/models/layers/test_flash_torch.py
@@ -23,9 +23,7 @@ def allclose_helper(t0: torch.Tensor,
@pytest.mark.gpu
@pytest.mark.parametrize('attn_impl_0, attn_impl_1', [
- ('flash', 'triton'),
('flash', 'torch'),
- ('triton', 'torch'),
])
@pytest.mark.parametrize('clip_qkv', [True, False])
@pytest.mark.parametrize('qk_ln, qk_gn', [
@@ -146,7 +144,6 @@ def gen_bias(attn_impl: str):
cfg.n_heads,
s,
alibi,
- prefix_lm=False,
use_sequence_id=attn_uses_sequence_id,
causal=causal)
if bs is not None:
@@ -291,7 +288,7 @@ def gen_bias(attn_impl: str):
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['flash', 'triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
def test_vs_mha(attn_impl: str, device: str = 'cuda'):
"""Compare diff attn_impl to torch.nn.MultiheadAttention."""
from llmfoundry.models.layers import attention
@@ -388,7 +385,7 @@ def gen_tca_mask():
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['flash', 'triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
@pytest.mark.parametrize('n_heads', [16, 8])
@pytest.mark.parametrize('kv_n_heads', [4, 2, 1])
def test_grouped_attention_heads(attn_impl: str,
diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py
index dfd3b17f96..1e8ec2383d 100644
--- a/tests/models/layers/test_huggingface_flash.py
+++ b/tests/models/layers/test_huggingface_flash.py
@@ -3,105 +3,16 @@
import contextlib
import os
-from unittest.mock import patch
import pytest
-import torch
-import transformers
from composer.core.precision import get_precision_context
-from composer.utils import reproducibility
from omegaconf import OmegaConf as om
-from transformers.models.llama.modeling_llama import LlamaAttention
from llmfoundry.models.hf.hf_fsdp import rgetattr
from llmfoundry.models.layers.attention import is_flash_v2_installed
-from llmfoundry.models.layers.llama_attention_monkeypatch import (
- llama_attention_patch_torch, llama_attention_patch_triton)
from llmfoundry.utils.builders import build_composer_model, build_tokenizer
-@pytest.mark.parametrize('patch_fn_name', ['torch', 'triton'])
-@pytest.mark.parametrize('explicit_mask', [True, False])
-@pytest.mark.parametrize(
- 'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
-@pytest.mark.gpu
-def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
- model_name: str):
- if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
- pytest.skip(
- 'The CI cluster does not have access to the Llama models, so skip this test.'
- )
-
- device = 'cuda:0'
- sequence_length = 64
- model_dim = 128 if '7b' in model_name else 256
- batch_size = 2
- if patch_fn_name == 'torch':
- patch_fn = llama_attention_patch_torch
- dtype = torch.float32
- atol = 0.0
- rtol = 0.0
- elif patch_fn_name == 'triton':
- # the huggingface implementation of llama performs the softmax in fp32
- # this can result in fairly large differences for the triton implementation
- # but the torch implementation produces the exact same output so we can confirm
- # the implementation is correct
- patch_fn = llama_attention_patch_triton
- dtype = torch.bfloat16
- atol = 1e-2
- rtol = 1e-2
- else:
- raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
-
- llama_config = transformers.AutoConfig.from_pretrained(
- model_name, use_auth_token=True, hidden_size=model_dim)
-
- reproducibility.seed_all(42)
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
-
- rng = torch.Generator(device=device).manual_seed(42)
- hidden_states = torch.randn(batch_size,
- sequence_length,
- model_dim,
- generator=rng,
- dtype=dtype,
- device=device)
- causal_mask = torch.full((sequence_length, sequence_length),
- torch.finfo(torch.float32).min,
- device=device)
- causal_mask = causal_mask.triu(diagonal=1)
- causal_mask = causal_mask[None,
- None, :, :].expand(batch_size, 1, sequence_length,
- sequence_length)
- position_ids = torch.arange(sequence_length,
- dtype=torch.long,
- device=device)
- position_ids = position_ids[None, :].expand(batch_size, sequence_length)
-
- attn_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=position_ids,
- past_key_value=None,
- use_cache=False,
- )
-
- reproducibility.seed_all(42)
- with patch.object(LlamaAttention, 'forward', new=patch_fn):
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
- new_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=position_ids,
- past_key_value=None,
- use_cache=False,
- )
-
- assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
-
-
@pytest.mark.gpu
@pytest.mark.world_size(2)
@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 79a5e4f98f..7244ddc8c2 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -26,7 +26,7 @@
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
from llmfoundry import ComposerHFCausalLM
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias
from llmfoundry.models.layers.attention import (check_alibi_support,
is_flash_v2_installed)
@@ -108,7 +108,7 @@ def gen_random_batch(batch_size: int,
# default to only input ids
if inputs == None:
inputs = ['input_ids']
- # generate input batch of random data, suitable for a Causal or Prefix LM
+ # generate input batch of random data, suitable for a Causal LM
batch = {}
for inp in inputs:
if inp == 'input_ids':
@@ -128,8 +128,6 @@ def gen_random_batch(batch_size: int,
batch['attention_mask'] = torch.ones(size=(batch_size,
test_cfg.max_seq_len),
dtype=torch.int64).to(test_cfg.device)
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
- batch['bidirectional_mask'][:, (test_cfg.max_seq_len // 2):] = 0
return batch
@@ -257,9 +255,7 @@ def test_attention_mechanism(batch_size: int = 2):
x = x + block.resid_ffn_dropout(n)
-@pytest.mark.parametrize('prefixlm', [False, True])
-def test_full_forward_and_backward_gpt2_small(prefixlm: bool,
- batch_size: int = 2):
+def test_full_forward_and_backward_gpt2_small(batch_size: int = 2):
warnings.filterwarnings(
action='ignore',
message='Torchmetrics v0.9 introduced a new argument class property')
@@ -270,11 +266,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm: bool,
device = 'cpu'
neo_cfg.device = device
neo_cfg.max_seq_len = 256
-
- if prefixlm:
- neo_cfg.model.name = 'hf_prefix_lm'
- else:
- neo_cfg.model.name = 'hf_causal_lm'
+ neo_cfg.model.name = 'hf_causal_lm'
tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(neo_cfg.tokenizer)
tokenizer = build_tokenizer(neo_cfg.tokenizer.name,
@@ -687,7 +679,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
@pytest.mark.gpu
-@pytest.mark.parametrize('attention_impl', ['flash', 'triton', 'torch'])
+@pytest.mark.parametrize('attention_impl', ['flash', 'torch'])
@pytest.mark.parametrize('pos_emb_config', [{
'alibi': True,
'rope': False
@@ -799,7 +791,6 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict):
@pytest.mark.parametrize('attention_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
pytest.param('torch', marks=pytest.mark.gpu)
])
@pytest.mark.parametrize('pos_emb_config', [{
@@ -1002,65 +993,9 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
atol=pad_vs_unpad_atol)
-@pytest.mark.parametrize('attention_impl', ['torch', 'triton'])
-def test_advanced_mask_building(attention_impl: str):
- # Test that the correct attention mask is created when both
- # prefix_mask and sequence_id are used
- hf_config = MPTConfig(
- init_device='cpu',
- d_model=16,
- n_heads=1,
- n_layers=1,
- expansion_ratio=1,
- max_seq_len=256,
- emb_pdrop=0.0,
- resid_pdrop=0.0,
- attn_config={
- 'attn_impl': attention_impl,
- 'prefix_lm': True,
- 'attn_uses_sequence_id': True,
- 'alibi': False,
- },
- )
- mpt = MPTForCausalLM(hf_config)
- mpt.eval()
-
- prefix_mask = torch.ByteTensor([[1, 1, 0, 0, 1, 1, 1, 0]])
- sequence_id = torch.LongTensor([[0, 0, 0, 0, 1, 1, 1, 1]])
-
- attn_bias, _ = mpt.transformer._attn_bias(device=mpt.device,
- dtype=torch.float32,
- attention_mask=None,
- prefix_mask=prefix_mask,
- sequence_id=sequence_id)
-
- assert isinstance(attn_bias, torch.Tensor)
- assert attn_bias.shape == torch.Size([1, 1, 8, 8])
-
- # We'll construct the expected value of attn_bias and then compare.
- can_attend = torch.tensor([
- [1, 1, 0, 0, 0, 0, 0, 0],
- [1, 1, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 0, 0, 0, 0],
- [0, 0, 0, 0, 1, 1, 1, 0],
- [0, 0, 0, 0, 1, 1, 1, 0],
- [0, 0, 0, 0, 1, 1, 1, 0],
- [0, 0, 0, 0, 1, 1, 1, 1],
- ])
- can_attend = can_attend.bool().view(1, 1, 8, 8)
- expected_attn_bias = torch.zeros_like(attn_bias)
- expected_attn_bias = expected_attn_bias.masked_fill(
- torch.logical_not(can_attend),
- torch.finfo(attn_bias.dtype).min)
-
- assert torch.equal(attn_bias, expected_attn_bias)
-
-
@pytest.mark.parametrize('attention_impl,precision', [
('torch', 'fp32'),
pytest.param('flash', 'amp_bf16', marks=pytest.mark.gpu),
- pytest.param('triton', 'amp_bf16', marks=pytest.mark.gpu),
pytest.param('torch', 'amp_bf16', marks=pytest.mark.gpu),
pytest.param('torch', 'fp32', marks=pytest.mark.gpu),
])
@@ -1313,7 +1248,6 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('pos_emb_config', [{
'alibi': False,
@@ -1447,7 +1381,6 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('pos_emb_config', [{
'alibi': False,
@@ -1586,7 +1519,6 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict,
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('pos_emb_config', [{
'alibi': False,
@@ -1685,7 +1617,6 @@ def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict,
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('generation_kwargs', [{
'max_new_tokens': 2,
@@ -1882,7 +1813,6 @@ def test_alibi_vs_hf():
@pytest.mark.parametrize('attn_impl', [
'torch',
pytest.param('flash', marks=pytest.mark.gpu),
- pytest.param('triton', marks=pytest.mark.gpu),
pytest.param('torch', marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('pos_emb_config', [{
@@ -1915,7 +1845,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
attn_impl: str, pos_emb_config: dict):
if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
- if attn_impl in ['flash', 'triton']:
+ if attn_impl == 'flash':
pytest.skip(f'output_attentions only implemented with torch attention.')
if pos_emb_config['rope'] and pos_emb_config[
'rope_impl'] == 'dail' and not is_flash_v2_installed():
@@ -2039,7 +1969,7 @@ def test_hf_init(tmp_path: pathlib.Path,
prepare_fsdp_module(model, optimizer, fsdp_config, precision, device, False)
- model = HuggingFaceModelWithZLoss(model, tokenizer)
+ model = HuggingFaceModelWithFSDP(model, tokenizer)
batch = gen_random_batch(batch_size, test_cfg)
@@ -2058,7 +1988,7 @@ def test_hf_init(tmp_path: pathlib.Path,
@pytest.mark.gpu
-def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
+def test_head_dim_8_flash_mqa_attn(batch_size: int = 2):
test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
test_cfg.device = torch.cuda.current_device()
@@ -2074,7 +2004,7 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
emb_pdrop=0.1,
resid_pdrop=0.2,
attn_config={
- 'attn_impl': 'triton',
+ 'attn_impl': 'flash',
'attn_type': 'multiquery_attention'
},
)
@@ -2086,7 +2016,7 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
mpt = MPTForCausalLM(hf_config)
- model = HuggingFaceModelWithZLoss(mpt, tokenizer, shift_labels=True)
+ model = HuggingFaceModelWithFSDP(mpt, tokenizer, shift_labels=True)
model = model.to(test_cfg.device)
batch = gen_random_batch(batch_size, test_cfg)
diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py
index 9f022ef487..35f130cd46 100644
--- a/tests/models/test_mpt_gen.py
+++ b/tests/models/test_mpt_gen.py
@@ -28,7 +28,6 @@ def forward(
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
- prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
@@ -38,7 +37,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
):
result = super().forward(input_ids, past_key_values, attention_mask,
- prefix_mask, sequence_id, labels, return_dict,
+ sequence_id, labels, return_dict,
output_attentions, output_hidden_states,
use_cache, inputs_embeds)
# Modify the logits to select the next token.
@@ -53,7 +52,7 @@ def forward(
@pytest.mark.world_size(2)
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
@pytest.mark.parametrize('use_alibi', [True, False])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
@patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM',
@@ -93,7 +92,7 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
@pytest.mark.parametrize('use_alibi', [True, False])
def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
build_tiny_mpt: Callable[...,
@@ -144,7 +143,7 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
@pytest.mark.gpu
-@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
@pytest.mark.parametrize('use_alibi', [True, False])
def test_mpt_generate_callback_not_tied(
use_alibi: bool, attn_impl: str,