From 2d7390e5cd12e70e48681e8d2f60e96ada8a8d6d Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Mon, 25 Mar 2024 22:24:52 -0700
Subject: [PATCH] Remove llama attention patch (#1066)

---
 llmfoundry/models/hf/hf_causal_lm.py          |  13 --
 .../layers/llama_attention_monkeypatch.py     | 200 ------------------
 tests/models/layers/test_huggingface_flash.py |  80 -------
 3 files changed, 293 deletions(-)
 delete mode 100644 llmfoundry/models/layers/llama_attention_monkeypatch.py

diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 8f95ba06c2..38ed7a7e70 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -21,7 +21,6 @@
 from llmfoundry.models.layers.attention import is_flash_v2_installed
 from llmfoundry.models.utils import init_empty_weights
 from llmfoundry.utils.config_utils import pop_config
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
 
 if TYPE_CHECKING:
     from peft import PeftConfig
@@ -56,8 +55,6 @@ class ComposerHFCausalLM(HuggingFaceModelWithFSDP):
             cfg.use_train_metrics (bool, optional): Whether to use training metrics. Default: ``True``.
             cfg.load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``.
             cfg.init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``.
-            cfg.attention_patch_type (str, optional): Which attention patch to use for llama models. Default: ``None``.
-                Deprecated. Will automatically use flash attention 2.
             cfg.use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``.
         tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
     """
@@ -90,16 +87,6 @@ def __init__(self, om_model_config: DictConfig,
         init_device = om_model_config.get('init_device', 'cpu')
         # Resolve "mixed" init device to either "cpu" or "meta"
         resolved_init_device = hf_get_init_device(init_device)
-        attention_patch_type = om_model_config.get('attention_patch_type', None)
-        if attention_patch_type is not None:
-            warnings.warn(
-                VersionedDeprecationWarning(
-                    'attention_patch_type is deprecated and will automatically use flash attention 2. '
-                    +
-                    'We recommend `use_flash_attention_2: true` for llama models.',
-                    remove_version='0.7.0'))
-            use_flash_attention_2 = True
-
         requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
 
         if use_flash_attention_2 and not is_flash_v2_installed():
diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py
deleted file mode 100644
index 1f1d600693..0000000000
--- a/llmfoundry/models/layers/llama_attention_monkeypatch.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-# This file is copied and modified from
-# https://github.com/huggingface/transformers/blob/fe3c8ab1af558b95f67f5fafc0c55f09fd2b09db/src/transformers/models/llama/modeling_llama.py
-# See the clearly denoted code blocks for the main modifications (there are a few others like type ignores, and error messages)
-
-import logging
-from typing import Callable, Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention
-
-from llmfoundry.models.layers.attention import \
-    scaled_multihead_dot_product_attention
-
-log = logging.getLogger(__name__)
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """Equivalent of torch.repeat_interleave(x, dim=1,
-
-    repeats=n_rep).
-
-    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
-    (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :,
-                                  None, :, :].expand(batch, num_key_value_heads,
-                                                     n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
-                                 head_dim)
-
-
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    position_ids: Optional[torch.Tensor] = None,
-    unsqueeze_dim: int = 1,
-):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable:
-    if patch_fn_name == 'torch':
-        return llama_attention_patch_torch
-    else:
-        raise ValueError(
-            f'Unrecognized llama attention patch function: {patch_fn_name}')
-
-
-def llama_attention_patch_torch(
-    self: LlamaAttention,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_cache:
-        raise NotImplementedError(
-            'use_cache is not yet supported when patching Llama attention.')
-
-    bsz, q_len, _ = hidden_states.size()
-
-    if self.config.pretraining_tp > 1:
-        key_value_slicing = (self.num_key_value_heads *
-                             self.head_dim) // self.config.pretraining_tp
-        query_slices = self.q_proj.weight.split(
-            (self.num_heads * self.head_dim) // self.config.pretraining_tp,
-            dim=0)
-        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-        query_states = [
-            F.linear(hidden_states, query_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        query_states = torch.cat(query_states, dim=-1)
-
-        key_states = [
-            F.linear(hidden_states, key_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        key_states = torch.cat(key_states, dim=-1)
-
-        value_states = [
-            F.linear(hidden_states, value_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        value_states = torch.cat(value_states, dim=-1)
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-    query_states = query_states.view(bsz, q_len, self.num_heads,
-                                     self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
-                                 self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
-                                     self.head_dim).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, position_ids)
-
-    query_states, key_states = apply_rotary_pos_emb(
-        q=query_states,
-        k=key_states,
-        cos=cos,
-        sin=sin,
-        position_ids=None,
-    )
-
-    ### MAIN MODIFICATIONS START HERE ###
-    query_states = query_states.transpose(1, 2).view(
-        bsz, q_len, self.num_heads * self.head_dim)
-    key_states = key_states.transpose(1, 2).view(
-        bsz, q_len, self.num_key_value_heads * self.head_dim)
-    value_states = value_states.transpose(1, 2).view(
-        bsz, q_len, self.num_key_value_heads * self.head_dim)
-
-    attn_output, attn_weights, _ = scaled_multihead_dot_product_attention(
-        query=query_states,
-        key=key_states,
-        value=value_states,
-        n_heads=self.num_heads,
-        kv_n_heads=self.num_key_value_heads,
-        past_key_value=None,
-        softmax_scale=None,
-        attn_bias=attention_mask,
-        key_padding_mask=None,
-        is_causal=False,  # The causal mask is propagated from LLamaForCausalLM
-        dropout_p=0,
-        training=self.training,
-        needs_weights=False,
-    )
-    ### MAIN MODIFICATIONS END HERE ###
-
-    if self.config.pretraining_tp > 1:
-        attn_output = attn_output.split(self.hidden_size //
-                                        self.config.pretraining_tp,
-                                        dim=2)
-        o_proj_slices = self.o_proj.weight.split(self.hidden_size //
-                                                 self.config.pretraining_tp,
-                                                 dim=1)
-        attn_output = sum([
-            F.linear(attn_output[i], o_proj_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ])
-    else:
-        attn_output = self.o_proj(attn_output)
-
-    assert isinstance(attn_output, torch.Tensor)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, None
diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py
index 4e35aeb153..1e8ec2383d 100644
--- a/tests/models/layers/test_huggingface_flash.py
+++ b/tests/models/layers/test_huggingface_flash.py
@@ -3,96 +3,16 @@
 
 import contextlib
 import os
-from unittest.mock import patch
 
 import pytest
-import torch
-import transformers
 from composer.core.precision import get_precision_context
-from composer.utils import reproducibility
 from omegaconf import OmegaConf as om
-from transformers.models.llama.modeling_llama import LlamaAttention
 
 from llmfoundry.models.hf.hf_fsdp import rgetattr
 from llmfoundry.models.layers.attention import is_flash_v2_installed
-from llmfoundry.models.layers.llama_attention_monkeypatch import \
-    llama_attention_patch_torch
 from llmfoundry.utils.builders import build_composer_model, build_tokenizer
 
 
-@pytest.mark.parametrize('patch_fn_name', ['torch'])
-@pytest.mark.parametrize('explicit_mask', [True, False])
-@pytest.mark.parametrize(
-    'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
-@pytest.mark.gpu
-def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
-                           model_name: str):
-    if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
-        pytest.skip(
-            'The CI cluster does not have access to the Llama models, so skip this test.'
-        )
-
-    device = 'cuda:0'
-    sequence_length = 64
-    model_dim = 128 if '7b' in model_name else 256
-    batch_size = 2
-    if patch_fn_name == 'torch':
-        patch_fn = llama_attention_patch_torch
-        dtype = torch.float32
-        atol = 0.0
-        rtol = 0.0
-    else:
-        raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
-
-    llama_config = transformers.AutoConfig.from_pretrained(
-        model_name, use_auth_token=True, hidden_size=model_dim)
-
-    reproducibility.seed_all(42)
-    attention = LlamaAttention(config=llama_config,)
-    attention.to(dtype=dtype, device=device)
-
-    rng = torch.Generator(device=device).manual_seed(42)
-    hidden_states = torch.randn(batch_size,
-                                sequence_length,
-                                model_dim,
-                                generator=rng,
-                                dtype=dtype,
-                                device=device)
-    causal_mask = torch.full((sequence_length, sequence_length),
-                             torch.finfo(torch.float32).min,
-                             device=device)
-    causal_mask = causal_mask.triu(diagonal=1)
-    causal_mask = causal_mask[None,
-                              None, :, :].expand(batch_size, 1, sequence_length,
-                                                 sequence_length)
-    position_ids = torch.arange(sequence_length,
-                                dtype=torch.long,
-                                device=device)
-    position_ids = position_ids[None, :].expand(batch_size, sequence_length)
-
-    attn_output, _, _ = attention(
-        hidden_states=hidden_states,
-        attention_mask=causal_mask if explicit_mask else None,
-        position_ids=position_ids,
-        past_key_value=None,
-        use_cache=False,
-    )
-
-    reproducibility.seed_all(42)
-    with patch.object(LlamaAttention, 'forward', new=patch_fn):
-        attention = LlamaAttention(config=llama_config,)
-        attention.to(dtype=dtype, device=device)
-        new_output, _, _ = attention(
-            hidden_states=hidden_states,
-            attention_mask=causal_mask if explicit_mask else None,
-            position_ids=position_ids,
-            past_key_value=None,
-            use_cache=False,
-        )
-
-    assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
-
-
 @pytest.mark.gpu
 @pytest.mark.world_size(2)
 @pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
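
Migration note (not part of the patch): the removed deprecation path simply forced use_flash_attention_2 = True whenever attention_patch_type was set, and the removed warning message recommends `use_flash_attention_2: true` for llama models. The sketch below shows what a migrated model config could look like, assuming the standard OmegaConf-based model config consumed by ComposerHFCausalLM; aside from use_flash_attention_2, which is documented in the docstring context above, the key names and values here are illustrative.

    # Hypothetical migration sketch, not taken from this patch: request flash
    # attention 2 directly instead of relying on the removed
    # attention_patch_type deprecation path.
    from omegaconf import OmegaConf as om  # same alias as in the test file

    model_cfg = om.create({
        'name': 'hf_causal_lm',  # assumed registry name for ComposerHFCausalLM
        'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',  # example model
        'pretrained': True,
        # Replaces a former `attention_patch_type: torch` entry.
        'use_flash_attention_2': True,
    })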