Skip to content

Commit

Permalink
Remove llama attention patch (#1066)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Mar 26, 2024
1 parent b71e4b0 commit 2d7390e
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 293 deletions.
13 changes: 0 additions & 13 deletions llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
from llmfoundry.utils.config_utils import pop_config
from llmfoundry.utils.warnings import VersionedDeprecationWarning

if TYPE_CHECKING:
from peft import PeftConfig
Expand Down Expand Up @@ -56,8 +55,6 @@ class ComposerHFCausalLM(HuggingFaceModelWithFSDP):
cfg.use_train_metrics (bool, optional): Whether to use training metrics. Default: ``True``.
cfg.load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``.
cfg.init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``.
cfg.attention_patch_type (str, optional): Which attention patch to use for llama models. Default: ``None``.
Deprecated. Will automatically use flash attention 2.
cfg.use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``.
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
"""
Expand Down Expand Up @@ -90,16 +87,6 @@ def __init__(self, om_model_config: DictConfig,
init_device = om_model_config.get('init_device', 'cpu')
# Resolve "mixed" init device to either "cpu" or "meta"
resolved_init_device = hf_get_init_device(init_device)
attention_patch_type = om_model_config.get('attention_patch_type', None)
if attention_patch_type is not None:
warnings.warn(
VersionedDeprecationWarning(
'attention_patch_type is deprecated and will automatically use flash attention 2. '
+
'We recommend `use_flash_attention_2: true` for llama models.',
remove_version='0.7.0'))
use_flash_attention_2 = True

requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'

if use_flash_attention_2 and not is_flash_v2_installed():
Expand Down
200 changes: 0 additions & 200 deletions llmfoundry/models/layers/llama_attention_monkeypatch.py

This file was deleted.

80 changes: 0 additions & 80 deletions tests/models/layers/test_huggingface_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,96 +3,16 @@

import contextlib
import os
from unittest.mock import patch

import pytest
import torch
import transformers
from composer.core.precision import get_precision_context
from composer.utils import reproducibility
from omegaconf import OmegaConf as om
from transformers.models.llama.modeling_llama import LlamaAttention

from llmfoundry.models.hf.hf_fsdp import rgetattr
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.llama_attention_monkeypatch import \
llama_attention_patch_torch
from llmfoundry.utils.builders import build_composer_model, build_tokenizer


@pytest.mark.parametrize('patch_fn_name', ['torch'])
@pytest.mark.parametrize('explicit_mask', [True, False])
@pytest.mark.parametrize(
    'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
@pytest.mark.gpu
def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
                           model_name: str):
    """Compare a patched ``LlamaAttention.forward`` against the stock one.

    Two identically seeded ``LlamaAttention`` modules are run on the same
    inputs, once with the original ``forward`` and once with ``forward``
    replaced by the selected patch function. Their outputs are compared with
    ``torch.allclose`` using zero absolute and relative tolerance.

    Args:
        patch_fn_name (str): Key selecting the patch function; only
            ``'torch'`` is recognised.
        explicit_mask (bool): When ``True`` an explicit causal mask is passed
            as ``attention_mask``; otherwise ``None`` is passed.
        model_name (str): Hugging Face model id whose config is loaded (with
            ``hidden_size`` overridden to a small value).

    Raises:
        ValueError: If ``patch_fn_name`` is not a recognised key.
    """
    if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
        pytest.skip(
            'The CI cluster does not have access to the Llama models, so skip this test.'
        )

    # Only one patch implementation exists; reject anything else up front.
    if patch_fn_name != 'torch':
        raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
    patch_fn = llama_attention_patch_torch
    dtype = torch.float32
    atol = rtol = 0.0

    device = 'cuda:0'
    batch_size = 2
    sequence_length = 64
    model_dim = 256
    if '7b' in model_name:
        model_dim = 128

    llama_config = transformers.AutoConfig.from_pretrained(
        model_name, use_auth_token=True, hidden_size=model_dim)

    def _seeded_attention() -> LlamaAttention:
        # Re-seed before every construction so both modules are built from
        # the same global RNG state.
        reproducibility.seed_all(42)
        module = LlamaAttention(config=llama_config,)
        module.to(dtype=dtype, device=device)
        return module

    reference_attention = _seeded_attention()

    # The inputs draw from a dedicated generator rather than the global RNG.
    rng = torch.Generator(device=device).manual_seed(42)
    hidden_states = torch.randn(batch_size,
                                sequence_length,
                                model_dim,
                                generator=rng,
                                dtype=dtype,
                                device=device)

    # Strictly-upper-triangular mask filled with the float32 minimum,
    # broadcast to (batch_size, 1, sequence_length, sequence_length).
    mask_2d = torch.full((sequence_length, sequence_length),
                         torch.finfo(torch.float32).min,
                         device=device).triu(diagonal=1)
    causal_mask = mask_2d[None, None, :, :].expand(batch_size, 1,
                                                   sequence_length,
                                                   sequence_length)

    positions = torch.arange(sequence_length, dtype=torch.long, device=device)
    position_ids = positions[None, :].expand(batch_size, sequence_length)

    forward_kwargs = {
        'hidden_states': hidden_states,
        'attention_mask': causal_mask if explicit_mask else None,
        'position_ids': position_ids,
        'past_key_value': None,
        'use_cache': False,
    }

    attn_output, _, _ = reference_attention(**forward_kwargs)

    with patch.object(LlamaAttention, 'forward', new=patch_fn):
        patched_attention = _seeded_attention()
        new_output, _, _ = patched_attention(**forward_kwargs)

    assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)


@pytest.mark.gpu
@pytest.mark.world_size(2)
@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
Expand Down

0 comments on commit 2d7390e

Please sign in to comment.