diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index 3bb9eed043..51fa67993a 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -4,6 +4,11 @@
import torch
try:
+ # Before importing any transformers models, we need to disable transformers flash attention if
+ # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
+ # gated import otherwise.
+ import transformers
+
from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
@@ -14,8 +19,8 @@
ComposerHFT5)
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
- flash_attn_fn, scaled_multihead_dot_product_attention,
- triton_flash_attn_fn)
+ flash_attn_fn, is_flash_v1_installed,
+ scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
build_ffn)
@@ -24,6 +29,8 @@
MPTForCausalLM, MPTModel,
MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
+ if is_flash_v1_installed():
+ transformers.utils.is_flash_attn_available = lambda: False
except ImportError as e:
try:
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 13857e9bb9..eb90b07045 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -24,8 +24,7 @@
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.layers.llama_attention_monkeypatch import \
- get_llama_attention_patch_fn
+from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
try:
@@ -95,12 +94,28 @@ def __init__(self, om_model_config: Union[DictConfig,
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
+ use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
+ False)
+ if use_flash_attention_2 and not is_flash_v2_installed():
+ raise ValueError(
+ 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ + 'Please install flash_attn==2.3.2`.')
+
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)
+ # This is not how you are supposed to set this, but transformers currently only
+ # supports enabling flash attention 2 when using the from_pretrained API.
+ # We need to support it for both from_pretrained and from_config, so we have to
+ # set the private attribute here. This will just skip all of transformers'
+ # validation logic that it is ok to use flash attention 2, so we check
+ # whether it is installed above, and whether the chosen config supports it here.
+ # https://github.com/huggingface/transformers/issues/26878
+ config._flash_attn_2_enabled = use_flash_attention_2
+
# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
@@ -200,6 +215,9 @@ def __init__(self, om_model_config: Union[DictConfig,
)
from transformers.models.llama.modeling_llama import \
LlamaAttention
+
+ from llmfoundry.models.layers.llama_attention_monkeypatch import \
+ get_llama_attention_patch_fn
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False
diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py
index 41518a582a..45192e09dd 100644
--- a/llmfoundry/tokenizers/tiktoken.py
+++ b/llmfoundry/tokenizers/tiktoken.py
@@ -155,7 +155,7 @@ def convert_ids_to_tokens(
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
- return self.added_tokens_decoder[ids]
+ return str(self.added_tokens_decoder[ids])
return self._convert_id_to_token(ids)
@@ -171,7 +171,7 @@ def convert_ids_to_tokens(
if index in self.added_tokens_decoder:
tokens.append(self.encoding.decode(current_stream))
current_stream = []
- tokens.append(self.added_tokens_decoder[index])
+ tokens.append(str(self.added_tokens_decoder[index]))
else:
current_stream.append(index)
diff --git a/scripts/train/README.md b/scripts/train/README.md
index f10fdf59f0..4c706dc040 100644
--- a/scripts/train/README.md
+++ b/scripts/train/README.md
@@ -5,14 +5,15 @@ This README walks through pretraining and finetuning a large language model usin
#### Table of Contents
1. [Part 1: LLM Pretraining](#llmpretraining)
1. [Installation](#installation)
- 2. [Dataset Preparation](#datasetpreparation)
- 3. [How to start single and multi-node pretraining](#howtostartpretraining)
-2. [Part 2: LLM Finetuning](#llmfinetuning)
+ 1. [Dataset Preparation](#datasetpreparation)
+ 1. [How to start single and multi-node pretraining](#howtostartpretraining)
+1. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Using a dataset on the HuggingFace Hub](#hfdataset)
- 2. [Using a local dataset](#localdataset)
- 3. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
-3. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
-4. [FAQ: Optimizing Performance](#optimizingperformance)
+ 1. [Using a local dataset](#localdataset)
+ 1. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
+1. [Using Flash Attention](#flashattention)
+1. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
+1. [FAQ: Optimizing Performance](#optimizingperformance)
# Part 1: LLM Pretraining
@@ -332,6 +333,53 @@ train_loader:
...
```
+# Using Flash Attention
+
+Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. Next, how you specify to use Flash Attention depends on which model you are using.
+
+For MPT, you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: mpt_causal_lm
+ ...
+ attn_config:
+ # Will use either V1 or V2 depending on what is installed
+ # "triton" will use the Triton implementation
+ attn_impl: flash
+ ...
+```
+
+If loading MPT from the HuggingFace Hub, you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ pretrained_model_name_or_path: mosaicml/mpt-7b
+ ...
+ config_overrides:
+ # Will use either V1 or V2 depending on what is installed
+ # "triton" will use the Triton implementation
+ attn_config:
+ attn_impl: flash
+ ...
+```
+
+For any HuggingFace model that supports Flash Attention (e.g. Llama and Mistral), you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ use_flash_attention_2: True # Will be automatically set to True if Flash Attention V2 is installed and the model supports it
+ ...
+```
+HuggingFace models currently only support Flash Attention V2.
+
+For Llama specifically, we have another option if you would like to use the Triton implementation of Flash Attention. You can specify this in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
+ attention_patch_type: triton
+ ...
+```
# FAQ: How many GPUs do I need to train a LLM?
This is a complicated question in general, but if we assume that you are using FSDP with `FULL_SHARD`,
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 28ecb68e34..8c1c28eb5c 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
+import gc
import logging
import os
import sys
@@ -216,6 +217,12 @@ def main(cfg: DictConfig) -> Trainer:
os.environ[
'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'
+ # Set CUDA lazy loading
+ # This can save a bit of memory if not all modules are needed
+ cuda_load_lazy: bool = cfg.pop('cuda_load_lazy', True)
+ if cuda_load_lazy:
+ os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
+
# Set seed first
seed: int = pop_config(cfg, 'seed', must_exist=True)
reproducibility.seed_all(seed)
@@ -634,6 +641,7 @@ def main(cfg: DictConfig) -> Trainer:
print('Logging config')
log_config(logged_cfg)
torch.cuda.empty_cache()
+ gc.collect()
# Eval first if requested
if eval_first and trainer.state.timestamp.batch.value == 0:
diff --git a/setup.py b/setup.py
index d0ecc66160..63aac9d752 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@
install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
- 'transformers>=4.33,<4.34',
+ 'transformers>=4.34.1,<4.35',
'mosaicml-streaming>=0.6,<0.7',
'torch>=1.13.1,<2.1.1',
'datasets>=2.14.5,<2.15',
@@ -114,9 +114,10 @@
extra_deps['all-cpu'] = set(
dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
- if key != 'gpu-flash2')
-extra_deps['all-flash2'] = set(
- dep for key, deps in extra_deps.items() for dep in deps if key != 'gpu')
+ if key not in {'gpu-flash2', 'all-cpu'})
+extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
+ for dep in deps
+ if key not in {'gpu', 'all', 'all-cpu'})
setup(
name=_PACKAGE_NAME,
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index e7787754de..fcb2cc3a7e 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -138,6 +138,49 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase,
tokenizer1.__dict__['init_kwargs'].pop('auto_map', None)
tokenizer2.__dict__['init_kwargs'].pop('auto_map', None)
+ # Additional special tokens do not match between original tokenizer and loaded tokenizer due to transformers
+ # constructor differences
+ additional_special_tokens_1 = {
+ t if isinstance(t, str) else t.content
+ for t in tokenizer1.__dict__.pop('_additional_special_tokens', [])
+ }
+ additional_special_tokens_2 = {
+ t if isinstance(t, str) else t.content
+ for t in tokenizer2.__dict__.pop('_additional_special_tokens', [])
+ }
+ # Also pop it out of init_kwargs
+ tokenizer1.__dict__['init_kwargs'].pop('additional_special_tokens', None)
+ tokenizer2.__dict__['init_kwargs'].pop('additional_special_tokens', None)
+ tokenizer1.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
+ tokenizer2.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
+ # If the additional special tokens are the same (or a subset of each other), or if one of them is empty, then we are good
+ assert additional_special_tokens_1.issubset(
+ additional_special_tokens_2) or additional_special_tokens_2.issubset(
+ additional_special_tokens_1)
+
+ # The special token attributes may be strings or they may be AddedToken objects, so we just check string values
+ # First check that they have the same attrs
+ assert tokenizer1.SPECIAL_TOKENS_ATTRIBUTES == tokenizer2.SPECIAL_TOKENS_ATTRIBUTES
+ # Then check that the values are the same
+ for special_token_attr in tokenizer1.SPECIAL_TOKENS_ATTRIBUTES:
+ # Skip additional_special_tokens because we already checked it above
+ if special_token_attr == 'additional_special_tokens':
+ continue
+
+ # The init_kwargs can change between the original tokenizer and the loaded tokenizer,
+ # so we just pop them
+ tokenizer1.__dict__['init_kwargs'].pop(special_token_attr, None)
+ tokenizer2.__dict__['init_kwargs'].pop(special_token_attr, None)
+
+ attr1 = tokenizer1.__dict__.pop('_' + special_token_attr, None)
+ attr2 = tokenizer2.__dict__.pop('_' + special_token_attr, None)
+ if attr1 is None and attr2 is None:
+ continue
+
+ attr_value1 = attr1 if isinstance(attr1, str) else attr1.content
+ attr_value2 = attr2 if isinstance(attr2, str) else attr2.content
+ assert attr_value1 == attr_value2
+
assert tokenizer1.__dict__ == tokenizer2.__dict__
diff --git a/tests/test_huggingface_flash.py b/tests/test_huggingface_flash.py
new file mode 100644
index 0000000000..a71217ea1f
--- /dev/null
+++ b/tests/test_huggingface_flash.py
@@ -0,0 +1,195 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import contextlib
+import os
+from unittest.mock import patch
+
+import pytest
+import torch
+import transformers
+from composer.core.precision import get_precision_context
+from composer.utils import reproducibility
+from omegaconf import OmegaConf as om
+
+from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.models.hf.hf_fsdp import rgetattr
+from llmfoundry.models.layers.attention import (is_flash_v1_installed,
+ is_flash_v2_installed)
+from llmfoundry.utils.builders import build_tokenizer
+
+# Before importing any transformers models, we need to disable transformers flash attention if
+# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
+# gated import otherwise.
+if is_flash_v1_installed():
+ transformers.utils.is_flash_attn_available = lambda: False
+
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+from llmfoundry.models.layers.llama_attention_monkeypatch import (
+ llama_attention_patch_torch, llama_attention_patch_triton)
+
+
+@pytest.mark.parametrize('patch_fn_name', ['torch', 'triton'])
+@pytest.mark.parametrize('explicit_mask', [True, False])
+@pytest.mark.parametrize(
+ 'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
+@pytest.mark.gpu
+def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
+ model_name: str):
+ if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
+ pytest.skip(
+ 'The CI cluster does not have access to the Llama models, so skip this test.'
+ )
+
+ device = 'cuda:0'
+ sequence_length = 4096
+ model_dim = 4096 if '7b' in model_name else 8192
+ batch_size = 2
+ if patch_fn_name == 'torch':
+ patch_fn = llama_attention_patch_torch
+ dtype = torch.float32
+ atol = 0.0
+ rtol = 0.0
+ elif patch_fn_name == 'triton':
+ # the huggingface implementation of llama performs the softmax in fp32
+ # this can result in fairly large differences for the triton implementation
+ # but the torch implementation produces the exact same output so we can confirm
+ # the implementation is correct
+ patch_fn = llama_attention_patch_triton
+ dtype = torch.bfloat16
+ atol = 1e-2
+ rtol = 1e-2
+ else:
+ raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
+
+ llama_config = transformers.AutoConfig.from_pretrained(model_name,
+ use_auth_token=True)
+
+ reproducibility.seed_all(42)
+ attention = LlamaAttention(config=llama_config,)
+ attention.to(dtype=dtype, device=device)
+
+ rng = torch.Generator(device=device).manual_seed(42)
+ hidden_states = torch.randn(batch_size,
+ sequence_length,
+ model_dim,
+ generator=rng,
+ dtype=dtype,
+ device=device)
+ causal_mask = torch.full((sequence_length, sequence_length),
+ torch.finfo(torch.float32).min,
+ device=device)
+ causal_mask = causal_mask.triu(diagonal=1)
+ causal_mask = causal_mask[None,
+ None, :, :].expand(batch_size, 1, sequence_length,
+ sequence_length)
+ attn_output, _, _ = attention(
+ hidden_states=hidden_states,
+ attention_mask=causal_mask if explicit_mask else None,
+ position_ids=None,
+ past_key_value=None,
+ use_cache=False,
+ )
+
+ reproducibility.seed_all(42)
+ with patch.object(LlamaAttention, 'forward', new=patch_fn):
+ attention = LlamaAttention(config=llama_config,)
+ attention.to(dtype=dtype, device=device)
+ new_output, _, _ = attention(
+ hidden_states=hidden_states,
+ attention_mask=causal_mask if explicit_mask else None,
+ position_ids=None,
+ past_key_value=None,
+ use_cache=False,
+ )
+
+ assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
+@pytest.mark.parametrize('use_flash_attention_2', [True, False])
+def test_flash2(model_name: str, use_flash_attention_2: bool):
+ if model_name == 'llama2':
+ if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
+ pytest.skip(
+ 'The CI cluster does not have access to the Llama models, so skip this test.'
+ )
+ model_cfg = {
+ 'name': 'hf_causal_lm',
+ 'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
+ 'config_overrides': {
+ 'num_hidden_layers': 2,
+ 'intermediate_size': 64,
+ },
+ 'use_auth_token': True,
+ 'pretrained': False,
+ 'init_device': 'cpu',
+ }
+
+ tokenizer_name = 'meta-llama/Llama-2-7b-hf'
+ from transformers.models.llama.modeling_llama import (
+ LlamaAttention, LlamaFlashAttention2)
+ flash_attn_class = LlamaFlashAttention2 if use_flash_attention_2 else LlamaAttention
+ attention_layers_attr = 'model.model.layers'
+ attention_attr = 'self_attn'
+ elif model_name == 'mistral':
+ model_cfg = {
+ 'name': 'hf_causal_lm',
+ 'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1',
+ 'config_overrides': {
+ 'num_hidden_layers': 2,
+ 'intermediate_size': 64,
+ },
+ 'pretrained': False,
+ 'init_device': 'cpu',
+ }
+
+ tokenizer_name = 'mistralai/Mistral-7B-v0.1'
+ from transformers.models.mistral.modeling_mistral import (
+ MistralAttention, MistralFlashAttention2)
+ flash_attn_class = MistralFlashAttention2 if use_flash_attention_2 else MistralAttention
+ attention_layers_attr = 'model.model.layers'
+ attention_attr = 'self_attn'
+ else:
+ raise ValueError(f'Unknown model: {model_name}')
+
+ if use_flash_attention_2:
+ model_cfg['use_flash_attention_2'] = True
+
+ model_cfg = om.create(model_cfg)
+
+ tokenizer = build_tokenizer(
+ tokenizer_name=tokenizer_name,
+ tokenizer_kwargs={'model_max_length': 10},
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+
+ error_context = pytest.raises(
+ ValueError, match='use_flash_attention_2 is set to True'
+ ) if not is_flash_v2_installed(
+ ) and use_flash_attention_2 else contextlib.nullcontext()
+
+ with error_context:
+ model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)
+
+ # check that it actually used flash attention 2
+ assert model.model.config._flash_attn_2_enabled if use_flash_attention_2 else not model.model.config._flash_attn_2_enabled
+ attention_layer = rgetattr(
+ rgetattr(model, attention_layers_attr)[0], attention_attr)
+ assert isinstance(attention_layer, flash_attn_class)
+
+ tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
+ return_tensors='pt',
+ padding=True)
+ tokenized_input['labels'] = tokenized_input['input_ids'].clone()
+
+ tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
+ model.to('cuda')
+
+ with get_precision_context('amp_bf16'):
+ # We're just testing that flash attention 2 runs okay
+ outputs = model(tokenized_input)
+ loss = outputs.loss
+ loss.backward()
diff --git a/tests/test_llama_patch.py b/tests/test_llama_patch.py
deleted file mode 100644
index b1cd3711e0..0000000000
--- a/tests/test_llama_patch.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-import os
-
-import pytest
-import torch
-import transformers
-from composer.utils import reproducibility
-from transformers.models.llama.modeling_llama import LlamaAttention
-
-from llmfoundry.models.layers.llama_attention_monkeypatch import (
- llama_attention_patch_torch, llama_attention_patch_triton)
-
-
-@pytest.mark.parametrize('patch_fn_name', ['torch', 'triton'])
-@pytest.mark.parametrize('explicit_mask', [True, False])
-@pytest.mark.parametrize(
- 'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
-@pytest.mark.gpu
-def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
- model_name: str):
- if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
- pytest.skip(
- 'The CI cluster does not have access to the Llama models, so skip this test.'
- )
-
- original_forward = LlamaAttention.forward
-
- device = 'cuda:0'
- sequence_length = 4096
- model_dim = 4096 if '7b' in model_name else 8192
- batch_size = 2
- if patch_fn_name == 'torch':
- patch_fn = llama_attention_patch_torch
- dtype = torch.float32
- atol = 0.0
- rtol = 0.0
- elif patch_fn_name == 'triton':
- # the huggingface implementation of llama performs the softmax in fp32
- # this can result in fairly large differences for the triton implementation
- # but the torch implementation produces the exact same output so we can confirm
- # the implementation is correct
- patch_fn = llama_attention_patch_triton
- dtype = torch.bfloat16
- atol = 1e-2
- rtol = 1e-2
- else:
- raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
-
- llama_config = transformers.AutoConfig.from_pretrained(model_name,
- use_auth_token=True)
-
- reproducibility.seed_all(42)
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
-
- rng = torch.Generator(device=device).manual_seed(42)
- hidden_states = torch.randn(batch_size,
- sequence_length,
- model_dim,
- generator=rng,
- dtype=dtype,
- device=device)
- causal_mask = torch.full((sequence_length, sequence_length),
- torch.finfo(torch.float32).min,
- device=device)
- causal_mask = causal_mask.triu(diagonal=1)
- causal_mask = causal_mask[None,
- None, :, :].expand(batch_size, 1, sequence_length,
- sequence_length)
- attn_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=None,
- past_key_value=None,
- use_cache=False,
- )
-
- reproducibility.seed_all(42)
- LlamaAttention.forward = patch_fn
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
- new_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=None,
- past_key_value=None,
- use_cache=False,
- )
-
- # Reset the forward function so patches don't persist
- LlamaAttention.forward = original_forward
-
- assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)