diff --git a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py index be5a705595..1275906804 100644 --- a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple import torch +from accelerate import init_empty_weights from flash_attn.flash_attn_interface import flash_attn_func from transformers import AutoConfig, AutoModelForCausalLM @@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"): # this is a wonky hack to get the remotely loaded module model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) # we need to load the model here in order for modeling_btlm to be available - AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + with init_empty_weights(): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) module_name = model_config.__class__.__module__.replace( ".configuration_btlm", ".modeling_btlm" ) diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py new file mode 100644 index 0000000000..0269f90157 --- /dev/null +++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py @@ -0,0 +1,415 @@ +# coding=utf-8 +# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This code is based off the following work: +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py +""" PyTorch StableLM Epoch model. """ +import importlib +import math +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from accelerate import init_empty_weights +from einops import rearrange +from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports + flash_attn_varlen_qkvpacked_func, +) +from torch import nn +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.utils import logging + +from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids + +logger = logging.get_logger(__name__) + + +def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"): + # this is a wonky hack to get the remotely loaded module + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + # we need to load the model here in order for modeling_stablelm_epoch to be available + with init_empty_weights(): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + module_name = model_config.__class__.__module__.replace( + ".configuration_stablelm_epoch", ".modeling_stablelm_epoch" + ) + modeling_stablelm = importlib.import_module(module_name) + modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access + flashattn_attn + ) + modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access + stablelm_model_forward + ) + modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access + decoder_layer_forward + ) + + +def rotate_half(x: torch.Tensor): + """Rotates half the hidden dims of the input.""" + # pylint: disable=invalid-name + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + # pylint: disable=invalid-name + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def flashattn_attn( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, # pylint: disable=unused-argument + use_cache: Optional[bool] = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + query_rot = query_states[..., : self.rotary_ndims] + query_pass = query_states[..., self.rotary_ndims :] + key_rot = key_states[..., : self.rotary_ndims] + key_pass = key_states[..., self.rotary_ndims :] + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_rot, key_rot, cos, sin, position_ids + ) + + # [batch_size, num_heads, seq_len, head_dim] + query_states = torch.cat((query_states, query_pass), dim=-1) + key_states = torch.cat((key_states, key_pass), dim=-1) + + if past_key_value is not None: + # Reuse k, v, self_attention + key_states = torch.cat((past_key_value[0], key_states), dim=2) + value_states = torch.cat((past_key_value[1], value_states), dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # Repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: + # special handling using sample packing + qkv = torch.stack( + [query_states, key_states, value_states], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + qkv = rearrange(qkv, "b s ... -> (b s) ...") + softmax_scale = None + + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True + ) + + attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz) + attn_output = rearrange(attn_output, "b s h d -> b s (h d)") + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # Upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + # Merge heads + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + # Final linear projection + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def decoder_layer_forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, +) -> Union[ + Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]] +]: + # pylint: disable=duplicate-code + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +def stablelm_model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + # pylint: disable=duplicate-code + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + cu_seqlens = None + max_seqlen = None + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) + cu_seqlens = cu_seqlens.squeeze() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # Embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = ( + self._prepare_decoder_attention_mask( # pylint: disable=protected-access + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # Decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + None, + cu_seqlens, + max_seqlen, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # Add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c05bccbf08..aa6049bd3e 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -124,6 +124,17 @@ def load_model( replace_btlm_attn_with_flash_attn(cfg.base_model) + if ( + hasattr(model_config, "model_type") + and model_config.model_type == "stablelm_epoch" + ): + if cfg.flash_attention and cfg.sample_packing: + from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( + replace_stablelm_attn_with_flash_attn, + ) + + replace_stablelm_attn_with_flash_attn(cfg.base_model) + if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing: if cfg.device not in ["mps", "cpu"] and not inference: from axolotl.monkeypatch.llama_attn_hijack_flash import (