From 814aee660306c89c8fcfba48d35b31db98467295 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jan 2024 12:54:36 -0500 Subject: [PATCH] Phi2 multipack (#1173) * phi2 multipack * update validation and examples for phi * more updates to phi examples * make sure to use the correct collator for phi multipack * phi needs attention mask now for multipack * if the special token already exists in the tokenizer, don't require in lora modules to save * fix qlora yml for phi, fix phi test validation * test qlora too * make sure flash attention is enabled for the test * don't use remote code for phi anymore * reduce sequence len for sample packing phi --- examples/phi/phi-ft.yml | 19 +- examples/phi/phi-qlora.yml | 21 +- examples/phi/phi2-ft.yml | 25 +- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/models/phi/__init__.py | 8 - .../phi/configuration_mixformer_sequential.py | 63 - src/axolotl/models/phi/configuration_phi.py | 65 - .../phi/modeling_mixformer_sequential.py | 930 -------------- src/axolotl/models/phi/modeling_phi.py | 1092 ----------------- src/axolotl/monkeypatch/phi/__init__.py | 12 + src/axolotl/utils/config.py | 14 - src/axolotl/utils/data.py | 2 +- src/axolotl/utils/lora_embeddings.py | 2 - src/axolotl/utils/models.py | 23 +- src/axolotl/utils/trainer.py | 9 +- tests/e2e/patched/test_phi_multipack.py | 123 ++ tests/e2e/test_phi.py | 52 +- tests/test_validation.py | 8 +- 18 files changed, 201 insertions(+), 2269 deletions(-) delete mode 100644 src/axolotl/models/phi/__init__.py delete mode 100644 src/axolotl/models/phi/configuration_mixformer_sequential.py delete mode 100644 src/axolotl/models/phi/configuration_phi.py delete mode 100644 src/axolotl/models/phi/modeling_mixformer_sequential.py delete mode 100644 src/axolotl/models/phi/modeling_phi.py create mode 100644 src/axolotl/monkeypatch/phi/__init__.py create mode 100644 tests/e2e/patched/test_phi_multipack.py diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index cab280c2a3..b21386f707 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -1,8 +1,6 @@ base_model: microsoft/phi-1_5 -model_type: PhiForCausalLM +model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer -is_llama_derived_model: false -trust_remote_code: true load_in_8bit: false load_in_4bit: false @@ -18,7 +16,7 @@ output_dir: ./phi-sft-out sequence_len: 2048 sample_packing: true -pad_to_sequence_len: +pad_to_sequence_len: true adapter: lora_model_dir: @@ -35,7 +33,7 @@ wandb_name: wandb_log_model: gradient_accumulation_steps: 1 -micro_batch_size: 1 +micro_batch_size: 2 num_epochs: 4 optimizer: adamw_torch adam_beta2: 0.95 @@ -45,18 +43,20 @@ lr_scheduler: cosine learning_rate: 0.000003 train_on_inputs: false -group_by_length: true +group_by_length: false bf16: auto fp16: tf32: true -gradient_checkpointing: +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: True early_stopping_patience: resume_from_checkpoint: local_rank: logging_steps: 1 xformers_attention: -flash_attention: +flash_attention: true warmup_steps: 100 evals_per_epoch: 4 @@ -68,7 +68,4 @@ fsdp: fsdp_config: resize_token_embeddings_to_32x: true special_tokens: - bos_token: "<|endoftext|>" - eos_token: "<|endoftext|>" - unk_token: "<|endoftext|>" pad_token: "<|endoftext|>" diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index bb0ff40be9..d2b5d661c9 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -1,8 +1,6 @@ base_model: microsoft/phi-1_5 model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer -is_llama_derived_model: false -trust_remote_code: true load_in_8bit: false load_in_4bit: true @@ -16,9 +14,9 @@ dataset_prepared_path: val_set_size: 0.05 output_dir: ./phi-sft-out -sequence_len: 1024 -sample_packing: false # not CURRENTLY compatible with LoRAs -pad_to_sequence_len: +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true adapter: qlora lora_model_dir: @@ -35,7 +33,7 @@ wandb_name: wandb_log_model: gradient_accumulation_steps: 1 -micro_batch_size: 1 +micro_batch_size: 2 num_epochs: 4 optimizer: adamw_torch adam_beta2: 0.95 @@ -45,18 +43,20 @@ lr_scheduler: cosine learning_rate: 0.000003 train_on_inputs: false -group_by_length: true +group_by_length: false bf16: auto fp16: tf32: true -gradient_checkpointing: +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: True early_stopping_patience: resume_from_checkpoint: local_rank: logging_steps: 1 xformers_attention: -flash_attention: +flash_attention: true warmup_steps: 100 evals_per_epoch: 4 @@ -68,7 +68,4 @@ fsdp: fsdp_config: resize_token_embeddings_to_32x: true special_tokens: - bos_token: "<|endoftext|>" - eos_token: "<|endoftext|>" - unk_token: "<|endoftext|>" pad_token: "<|endoftext|>" diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml index af146ae643..7a2d05d018 100644 --- a/examples/phi/phi2-ft.yml +++ b/examples/phi/phi2-ft.yml @@ -1,8 +1,6 @@ base_model: microsoft/phi-2 -model_revision: 834565c # pin model repo to the previous architecture model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer -trust_remote_code: true load_in_8bit: false load_in_4bit: false @@ -17,19 +15,16 @@ val_set_size: 0.05 output_dir: ./phi-sft-out sequence_len: 2048 -sample_packing: false # currently unsupported -pad_to_sequence_len: +sample_packing: true +pad_to_sequence_len: true adapter: lora_model_dir: -lora_r: 16 -lora_alpha: 32 -lora_dropout: 0.1 -lora_target_linear: true +lora_r: +lora_alpha: +lora_dropout: +lora_target_linear: lora_fan_in_fan_out: -lora_modules_to_save: - - embd - - lm_head wandb_project: wandb_entity: @@ -38,14 +33,14 @@ wandb_name: wandb_log_model: gradient_accumulation_steps: 1 -micro_batch_size: 1 +micro_batch_size: 2 num_epochs: 4 -optimizer: paged_adamw_8bit +optimizer: adamw_torch adam_beta2: 0.95 adam_epsilon: 0.00001 max_grad_norm: 1.0 lr_scheduler: cosine -learning_rate: 1e-5 +learning_rate: 0.000003 train_on_inputs: false group_by_length: false @@ -54,6 +49,8 @@ fp16: tf32: true gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: True early_stopping_patience: resume_from_checkpoint: local_rank: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index e109db7f84..0f62aae9af 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -930,7 +930,7 @@ def build_collator( ] ] if use_batch_sampler_collator: - if self.cfg.model_config_type in ["mixtral", "qwen2"]: + if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]: collator = V2BatchSamplerDataCollatorForSeq2Seq else: collator = BatchSamplerDataCollatorForSeq2Seq diff --git a/src/axolotl/models/phi/__init__.py b/src/axolotl/models/phi/__init__.py deleted file mode 100644 index 76d6a0e10b..0000000000 --- a/src/axolotl/models/phi/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -MixFormers model architecture used for phi models -""" - -from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa -from .configuration_phi import PhiConfig # noqa -from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa -from .modeling_phi import PhiForCausalLM # noqa diff --git a/src/axolotl/models/phi/configuration_mixformer_sequential.py b/src/axolotl/models/phi/configuration_mixformer_sequential.py deleted file mode 100644 index ceba62093a..0000000000 --- a/src/axolotl/models/phi/configuration_mixformer_sequential.py +++ /dev/null @@ -1,63 +0,0 @@ -# pylint: skip-file - -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import math -from typing import Any, Dict, List, Optional, Union - -from transformers import PretrainedConfig - - -class MixFormerSequentialConfig(PretrainedConfig): - """MixFormer (sequential for DeepSpeed) configuration.""" - - model_type = "mixformer-sequential" - - attribute_map = { - "max_position_embeddings": "n_positions", - "hidden_size": "n_embd", - "num_attention_heads": "n_head", - "num_hidden_layers": "n_layer", - "input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility - "blocks": "architecture", # `blocks` key is for backward compatibility - } - - def __init__( - self, - vocab_size: Optional[int] = 50304, - n_positions: Optional[int] = 2048, - n_embd: Optional[int] = 1024, - n_layer: Optional[int] = 20, - n_inner: Optional[int] = None, - n_head: Optional[int] = 16, - rotary_dim: Optional[int] = 32, - activation_function: Optional[str] = "gelu_new", - embd_layer: Optional[str] = "default", - architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None, - embd_pdrop: Optional[float] = 0.0, - resid_pdrop: Optional[float] = 0.0, - layer_norm_epsilon: Optional[float] = 1e-5, - initializer_range: Optional[float] = 0.02, - tie_word_embeddings: Optional[bool] = False, - pad_vocab_size_multiple: Optional[int] = 64, - **kwargs - ) -> None: - self.vocab_size = int( - math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_inner = n_inner - self.n_head = n_head - self.rotary_dim = min(rotary_dim, n_embd // n_head) - self.activation_function = activation_function - self.embd_layer = embd_layer - self.architecture = architecture - self.embd_pdrop = embd_pdrop - self.resid_pdrop = resid_pdrop - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/src/axolotl/models/phi/configuration_phi.py b/src/axolotl/models/phi/configuration_phi.py deleted file mode 100644 index e941bf7980..0000000000 --- a/src/axolotl/models/phi/configuration_phi.py +++ /dev/null @@ -1,65 +0,0 @@ -# pylint: skip-file -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import math -from typing import Optional - -from transformers import PretrainedConfig - - -class PhiConfig(PretrainedConfig): - """Phi configuration.""" - - model_type = "phi" - attribute_map = { - "max_position_embeddings": "n_positions", - "hidden_size": "n_embd", - "num_attention_heads": "n_head", - "num_hidden_layers": "n_layer", - } - - def __init__( - self, - vocab_size: int = 50304, - n_positions: int = 2048, - n_embd: int = 1024, - n_layer: int = 20, - n_inner: Optional[int] = None, - n_head: int = 16, - n_head_kv: Optional[int] = None, - rotary_dim: Optional[int] = 32, - activation_function: Optional[str] = "gelu_new", - flash_attn: bool = False, - flash_rotary: bool = False, - fused_dense: bool = False, - attn_pdrop: float = 0.0, - embd_pdrop: float = 0.0, - resid_pdrop: float = 0.0, - layer_norm_epsilon: float = 1e-5, - initializer_range: float = 0.02, - tie_word_embeddings: bool = False, - pad_vocab_size_multiple: int = 64, - **kwargs - ) -> None: - self.vocab_size = int( - math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_inner = n_inner - self.n_head = n_head - self.n_head_kv = n_head_kv - self.rotary_dim = min(rotary_dim, n_embd // n_head) - self.activation_function = activation_function - self.flash_attn = flash_attn - self.flash_rotary = flash_rotary - self.fused_dense = fused_dense - self.attn_pdrop = attn_pdrop - self.embd_pdrop = embd_pdrop - self.resid_pdrop = resid_pdrop - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/src/axolotl/models/phi/modeling_mixformer_sequential.py b/src/axolotl/models/phi/modeling_mixformer_sequential.py deleted file mode 100644 index fd2ec054c5..0000000000 --- a/src/axolotl/models/phi/modeling_mixformer_sequential.py +++ /dev/null @@ -1,930 +0,0 @@ -# pylint: skip-file - -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -# BSD 3-Clause License -# -# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import annotations - -import copy -import inspect -from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Tuple - -import torch -import torch.nn as nn -from einops import rearrange -from flash_attn.flash_attn_interface import ( - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_qkvpacked_func, -) -from transformers import PretrainedConfig, PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import CausalLMOutputWithPast - -from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids -from .configuration_mixformer_sequential import MixFormerSequentialConfig - - -@dataclass -class InferenceParams: - """Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. - Adapted from https://github.com/Dao-AILab/flash-attention.""" - - max_sequence_len: int - max_batch_size: int - sequence_len_offset: int = 0 - batch_size_offset: int = 0 - key_value_memory_dict: dict = field(default_factory=dict) - fused_ft_kernel: bool = False - lengths_per_sample: Optional[torch.Tensor] = None - - -class Embedding(nn.Module): - """Token embedding with dropout.""" - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.wte = nn.Embedding(config.vocab_size, config.n_embd) - self.drop = nn.Dropout(config.embd_pdrop) - - def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.wte(input_ids) - hidden_states = self.drop(hidden_states) - - return hidden_states - - -class RotaryEmbedding(nn.Module): - """PyTorch implementation of `flash-attn` RotaryEmbedding layer. - Adapted from https://github.com/Dao-AILab/flash-attention.""" - - def __init__( - self, - dim: int, - base: Optional[int] = 10000, - scale_base: Optional[float] = None, - device: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__() - - if scale_base is not None: - raise NotImplementedError - - # Generate and save the inverse frequency buffer (non-trainable) - self.dim = dim - self.base = base - self.scale_base = scale_base - self.device = device - - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) - ) - self.register_buffer("inv_freq", inv_freq) - - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) - / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _update_cos_sin_cache( - self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0 - ) -> None: - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - seqlen = x.shape[1] + seqlen_offset - - # Re-generate the inverse frequency buffer if it's not fp32 - # (for instance if model.half() was called) - if self.inv_freq.dtype != "torch.float32": - self.inv_freq = 1.0 / ( - self.base - ** ( - torch.arange( - 0, self.dim, 2, device=self.device, dtype=torch.float32 - ) - / self.dim - ) - ) - - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != x.device - or self._cos_cached.dtype != x.dtype - ): - self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=x.device, dtype=torch.float32) - - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer( - t, self.inv_freq.to(device=t.device, dtype=torch.float32) - ) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(x.dtype) - self._sin_cached = torch.sin(freqs).to(x.dtype) - else: - power = ( - torch.arange( - seqlen, dtype=self.scale.dtype, device=self.scale.device - ) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange( - power, "s -> s 1" - ) - - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) - - def apply_rotary_emb_qkv( - self, - qkv: torch.FloatTensor, - sin: torch.FloatTensor, - cos: torch.FloatTensor, - sin_k: Optional[torch.FloatTensor] = None, - cos_k: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - _, seqlen, three, _, headdim = qkv.shape - assert three == 3 - - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert ( - sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - ) - - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - - # Splits the queries and keys in half - q1, q2 = q_rot.chunk(2, dim=-1) - k1, k2 = k_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange( - sin[:seqlen], "s d -> s 1 d" - ) - - # Casts to fp32 are necessary to prevent fp16 overflow issues - q1, q2, k1, k2, c, s = [ - t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s] - ] - - # Computes the new keys and queries, recasting to original dtype - q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype) - - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype) - - return torch.cat( - [ - torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2), - torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), - qkv[:, :, 2:3, :, :], - ], - axis=2, - ) - - def forward( - self, qkv: torch.Tensor, seqlen_offset: int = 0 - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Perform the forward pass. - - Args: - qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim). - seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch. - - Returns: - New `qkv` and the cached sinusoids. - - """ - - self._update_cos_sin_cache(qkv, seqlen_offset) - - return self.apply_rotary_emb_qkv( - qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:] - ) - - -def _update_kv_cache(kv, inference_params, layer_idx): - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) - Adapted from https://github.com/Dao-AILab/flash-attention.""" - # Pre-allocate memory for key-values for inference. - num_heads, head_dim = kv.shape[-2:] - if layer_idx not in inference_params.key_value_memory_dict: - kv_cache = torch.empty( - inference_params.max_batch_size, - inference_params.max_sequence_len, - 2, - num_heads, - head_dim, - dtype=kv.dtype, - device=kv.device, - ) - inference_params.key_value_memory_dict[layer_idx] = kv_cache - else: - kv_cache = inference_params.key_value_memory_dict[layer_idx] - - # Adjust key and value for inference - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + kv.shape[1] - assert batch_end <= ( - kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa - ) - assert sequence_end <= ( - kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa - ) - - assert kv_cache is not None - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv - kv = kv_cache[batch_start:batch_end, :sequence_end, ...] - return kv - - -class MLP(nn.Module): - """Multi-Layer Perceptron. - - Reference: - Attention Is All You Need. - https://arxiv.org/pdf/1706.03762.pdf. - - """ - - def __init__( - self, - config: PretrainedConfig, - n_inner: Optional[int] = None, - act_fn: Optional[str] = None, - ) -> None: - super().__init__() - - act_fn = config.activation_function if act_fn is None else act_fn - assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}." - - n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner - n_inner = n_inner if n_inner is not None else 4 * config.n_embd - - self.fc1 = nn.Linear(config.n_embd, n_inner) - self.fc2 = nn.Linear(n_inner, config.n_embd) - self.act = ACT2FN[act_fn] - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - old_keys = [ - prefix + "fc_in.weight", - prefix + "fc_out.weight", - prefix + "fc_in.bias", - prefix + "fc_out.bias", - ] - new_keys = [ - prefix + "fc1.weight", - prefix + "fc2.weight", - prefix + "fc1.bias", - prefix + "fc2.bias", - ] - - if all(k in state_dict for k in old_keys) and not all( - k in state_dict for k in new_keys - ): - # Older version of `MLP` saved with different key names. - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - return super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc2(hidden_states) - - return hidden_states - - -class FusedMLP(nn.Module): - """Fused Multi-Layer Perceptron from `flash-attn`. - - Reference: - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py. - - """ - - def __init__( - self, - config: PretrainedConfig, - n_inner: Optional[int] = None, - act_fn: Optional[str] = None, - raise_on_missing: bool = False, - ) -> None: - super().__init__() - - act_fn = config.activation_function if act_fn is None else act_fn - assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}." - - n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner - n_inner = n_inner if n_inner is not None else 4 * config.n_embd - - gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa - activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa - - self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - return self.mlp(hidden_states) - - -class SelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Adapted from https://github.com/Dao-AILab/flash-attention. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward( - self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None - ): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, S) - """ - causal = self.causal if causal is None else causal - if cu_seqlens is not None: - return flash_attn_varlen_qkvpacked_func( - qkv.squeeze(0), - cu_seqlens, - max_seqlen, - dropout_p=self.drop.p, - softmax_scale=self.softmax_scale, - causal=causal, - ) - else: - return flash_attn_qkvpacked_func( - qkv, - dropout_p=self.drop.p, - softmax_scale=self.softmax_scale, - causal=causal, - ) - - -class CrossAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Adapted from https://github.com/Dao-AILab/flash-attention. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward(self, q, kv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, Sk) - """ - causal = self.causal if causal is None else causal - return flash_attn_kvpacked_func( - q, - kv, - dropout_p=self.drop.p, - softmax_scale=self.softmax_scale, - causal=causal, - ) - - -def find_mha_dims( - config: PretrainedConfig, - n_head: Optional[int] = None, - head_dim: Optional[int] = None, -) -> Tuple[int, int]: - """Validate and return the number of heads and head dimension for multi-head attention. - - Args: - config: Model configuration. - n_head: Number of heads. - head_dim: Head dimension. - - Returns: - Number of heads and head dimension. - - """ - - assert all( - hasattr(config, attr) for attr in ["n_embd", "n_head"] - ), "`config` must have `n_embd` and `n_head` attributes." - - if head_dim is None: - assert ( - config.n_embd % config.n_head == 0 - ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})." - - if n_head is None and head_dim is None: - head_dim = config.n_embd // config.n_head - n_head = config.n_head - elif n_head is None or head_dim is None: - raise ValueError("`n_head` and `head_dim` must be both specified or `None`.") - - return n_head, head_dim - - -class MHA(nn.Module): - """Multi-head attention layer. - Adapted from https://github.com/Dao-AILab/flash-attention.""" - - def __init__( - self, - config: PretrainedConfig, - rotary_dim: Optional[int] = None, - n_head: Optional[int] = None, - head_dim: Optional[int] = None, - bias: Optional[bool] = True, - dropout: Optional[float] = 0.0, - softmax_scale: Optional[float] = None, - causal: Optional[bool] = True, - layer_idx: Optional[int] = None, - rotary_emb_scale_base: Optional[float] = None, - return_residual: Optional[bool] = False, - checkpointing: Optional[bool] = False, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - fused_dense: Optional[bool] = True, - flash_attn: Optional[bool] = True, - cutlass_attn: Optional[bool] = False, - flash_rotary: Optional[bool] = True, - raise_on_missing: Optional[bool] = False, - ) -> None: - super().__init__() - - factory_kwargs = {"device": device, "dtype": dtype} - n_head, head_dim = find_mha_dims(config, n_head, head_dim) - - self.hidden_size = config.n_embd - self.n_head = n_head - self.head_dim = head_dim - self.op_size = n_head * head_dim - - self.causal = causal - self.layer_idx = layer_idx - self.rotary_emb_dim = ( - rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0) - ) - self.fused_dense = fused_dense - self.flash_attn = flash_attn - self.cutlass_attn = cutlass_attn - self.flash_rotary = flash_rotary - self.return_residual = return_residual - self.checkpointing = checkpointing - - if self.rotary_emb_dim > 0: - rotary_kwargs = {"device": device} - if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0: - rotary_kwargs["scale_base"] = rotary_emb_scale_base - - self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs) - else: - pass - - self.Wqkv = nn.Linear( - self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs - ) - self.out_proj = nn.Linear( - self.op_size, self.hidden_size, bias=bias, **factory_kwargs - ) - - self.inner_attn = SelfAttention( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - self.inner_cross_attn = CrossAttention( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - - def _update_kv_cache( - self, kv: torch.FloatTensor, inference_params: InferenceParams - ) -> None: - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) - Adapted from https://github.com/Dao-AILab/flash-attention.""" - - assert ( - self.layer_idx is not None - ), "Generation requires layer_idx in the constructor" - - return _update_kv_cache(kv, inference_params, self.layer_idx) - - def forward( - self, - x: torch.FloatTensor, - x_kv: Optional[torch.FloatTensor] = None, - key_padding_mask: Optional[torch.BoolTensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - mixer_subset: Optional[torch.LongTensor] = None, - past_cache: Optional[InferenceParams] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - """Perform the forward pass. - - Args: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if - cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total - is the is the sum of the sequence lengths in the batch. - x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. - key_padding_mask: boolean mask, True means to keep, False means to mask out. - (batch, seqlen). Only applicable when not using FlashAttention. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into x. Only applicable when using - FlashAttention. - max_seqlen: int. Maximum sequence length in the batch. - mixer_subset: for cross-attention only. If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. - past_cache: For generation only. - - Returns: - (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None, - else (total, hidden_dim) where total is the is the sum of the sequence lengths - in the batch. - - """ - - if cu_seqlens is not None: - assert max_seqlen is not None - assert key_padding_mask is None - assert self.flash_attn - # assert self.rotary_emb_dim == 0 - - if key_padding_mask is not None: - assert cu_seqlens is None - assert max_seqlen is None - assert not self.flash_attn - - if past_cache is not None: - assert key_padding_mask is None - assert cu_seqlens is None and max_seqlen is None - - attn_kwargs = {"key_padding_mask": key_padding_mask} - - assert x_kv is None and mixer_subset is None - - qkv = self.Wqkv(x) - qkv = rearrange( - qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim - ) - - if past_cache is None: - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv) - context = self.inner_attn( - qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs - ) - - else: - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset) - q = qkv[:, :, 0] - kv = self._update_kv_cache(qkv[:, :, 1:], past_cache) - # If we're processing the prompt, causal=None (use self.causal). - # If we're decoding, then causal=False. - causal = None if past_cache.sequence_len_offset == 0 else False - context = self.inner_cross_attn(q, kv, causal=causal) - - out = rearrange(context, "... h d -> ... (h d)") - out = self.out_proj(out) - - return out if not self.return_residual else (out, x) - - -class ParallelBlock(nn.Module): - """Parallel block. - - This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen). - - """ - - def __init__( - self, - config: PretrainedConfig, - mixer: Optional[Dict[str, Any]] = None, - mlp: Optional[Dict[str, Any]] = None, - block_idx: Optional[int] = None, - ) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.block_idx = block_idx - - self.mixer = MHA(config, layer_idx=block_idx) - self.mlp = MLP(config) - - def forward( - self, - hidden_states: torch.FloatTensor, - past_cache: Optional[torch.FloatTensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - ) -> torch.FloatTensor: - residual = hidden_states - hidden_states = self.ln(hidden_states) - - attn_outputs = self.mixer( - hidden_states, - past_cache=past_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - if isinstance(attn_outputs, tuple): - attn_outputs = attn_outputs[0] - - attn_outputs = self.resid_dropout(attn_outputs) - feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) - - hidden_states = attn_outputs + feed_forward_hidden_states + residual - - return hidden_states - - -class CausalLMHead(nn.Module): - """Causal Language Modeling head. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. - - """ - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.linear = nn.Linear(config.n_embd, config.vocab_size) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.ln(hidden_states) - logits = self.linear(hidden_states).to(torch.float32) - - return logits - - -class CausalLMLoss(nn.Module): - """Causal Language Modeling loss. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. - - """ - - def __init__(self, shift_labels: Optional[bool] = True) -> None: - super().__init__() - - self.shift_labels = shift_labels - self.loss_fct = nn.CrossEntropyLoss() - - def forward( - self, logits: torch.FloatTensor, labels: torch.LongTensor - ) -> torch.FloatTensor: - if self.shift_labels: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - - loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - - return loss - - -class MixFormerSequentialPreTrainedModel(PreTrainedModel): - """MixFormer (sequential for DeepSpeed) pre-trained model.""" - - config_class = MixFormerSequentialConfig - base_model_prefix = "transformer" - supports_gradient_checkpointing = True - - def __init__(self, *inputs, **kwargs) -> None: - super().__init__(*inputs, **kwargs) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, **kwargs - ) -> Dict[str, Any]: - if "use_cache" in kwargs and not kwargs["use_cache"]: - return {"input_ids": input_ids} - - if past_key_values is None or not ( - isinstance(past_key_values, InferenceParams) - ): - past_key_values = InferenceParams( - max_batch_size=input_ids.shape[0], - max_sequence_len=self.config.n_positions, - sequence_len_offset=0, - batch_size_offset=0, - fused_ft_kernel=False, - key_value_memory_dict={}, - ) - else: - # assume past_key_values has cached all but last token in input_ids - past_key_values.sequence_len_offset = len(input_ids[0]) - 1 - input_ids = input_ids[:, -1].unsqueeze(-1) - - return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs} - - -class PackedSequential(nn.Sequential): - def forward( - self, - input, - cu_seqlens: Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - ): - for module in self: - sig = inspect.signature(module.forward) - if "cu_seqlens" in sig.parameters: - input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - else: - input = module(input) - return input - - -class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel): - """MixFormer (sequential for DeepSpeed) for Causal Language Modeling.""" - - _keys_to_ignore_on_load_missing = [""] - _keys_to_ignore_on_load_unexpected = [ - r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)" - ] - _no_split_modules = ["ParallelBlock"] - - def __init__(self, config: MixFormerSequentialConfig) -> None: - super().__init__(config) - - modules = [Embedding(config)] - block_config = config.architecture - - if not isinstance(block_config, list): - block_config = [block_config for _ in range(config.n_layer)] - - if config.n_layer != len(block_config): - config.n_layer = len(block_config) - - for block_idx, block in enumerate(block_config): - # `block_cls` with `legacy` value is for backward compatibility - # `path` key is for backward compatibility - block = copy.deepcopy(block) or {"block_cls": "parallel"} - block.pop("path", None) or block.pop("block_cls", None) - - block["block_idx"] = block_idx - modules.append(ParallelBlock(config, **block)) - - modules.append(CausalLMHead(config)) - - self.layers = PackedSequential(*modules) - self.loss = CausalLMLoss() - - self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - return self.layers[0].wte - - def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: - self.layers[0].wte = new_embeddings - - def get_output_embeddings(self) -> nn.Linear: - return self.layers[-1].linear - - def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: - self.layers[-1].linear = new_embeddings - - def forward( - self, - input_ids: torch.LongTensor, - labels: Optional[torch.LongTensor] = None, - past_key_values: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - **kwargs, - ) -> CausalLMOutputWithPast: - cu_seqlens: Optional[torch.LongTensor] = None - max_seqlen: Optional[int] = None - if position_ids is not None: - batch_size, seq_length = input_ids.shape - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if not past_key_values: - lm_logits = self.layers( - input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - else: - hidden_layer = self.layers[0](input_ids) - for module in self.layers[1:-1]: - hidden_layer = module( - hidden_layer, - past_cache=past_key_values, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - lm_logits = self.layers[-1](hidden_layer) - - loss = None - if labels is not None: - loss = self.loss(lm_logits, labels) - - return CausalLMOutputWithPast( - loss=loss, logits=lm_logits, past_key_values=past_key_values - ) diff --git a/src/axolotl/models/phi/modeling_phi.py b/src/axolotl/models/phi/modeling_phi.py deleted file mode 100644 index f28670749e..0000000000 --- a/src/axolotl/models/phi/modeling_phi.py +++ /dev/null @@ -1,1092 +0,0 @@ -# pylint: skip-file -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. -# -# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. -# Licensed under the BSD 3-Clause License. - -from __future__ import annotations - -import math -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn -from einops import rearrange, repeat -from torch.utils.checkpoint import checkpoint -from transformers import PretrainedConfig, PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import CausalLMOutputWithPast - -from .configuration_phi import PhiConfig - -try: - from flash_attn.bert_padding import pad_input, unpad_input - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding - from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention -except ImportError: - pad_input, unpad_input = None, None - FlashRotaryEmbedding = None - FlashSelfAttention, FlashCrossAttention = None, None - -# this is in a seperate try/except block since sometimes fused_dense isn't available -# and it shouldn't completely disable flash attn when it isn't -try: - from flash_attn.ops.fused_dense import FusedDense -except ImportError: - FusedDense = None - - -@dataclass -class InferenceParams: - """Inference parameters passed to model to efficiently calculate - and store context during inference. - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py. - - Args: - max_seqlen: Maximum sequence length. - max_batch_size: Maximum batch size. - seqlen_offset: Sequence length offset. - batch_size_offset: Batch size offset. - key_value_memory_dict: Key value memory dictionary. - lengths_per_sample: Lengths per sample. - - """ - - max_seqlen: int = field(metadata={"help": "Maximum sequence length."}) - - max_batch_size: int = field(metadata={"help": "Maximum batch size."}) - - seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."}) - - batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."}) - - key_value_memory_dict: Dict[str, Any] = field( - default_factory=dict, metadata={"help": "Key value memory dictionary."} - ) - - lengths_per_sample: torch.Tensor = field( - default=None, metadata={"help": "Lengths per sample."} - ) - - -class Embedding(nn.Module): - """Token embedding with dropout.""" - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.wte = nn.Embedding(config.vocab_size, config.n_embd) - self.drop = nn.Dropout(config.embd_pdrop) - - def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.wte(input_ids) - hidden_states = self.drop(hidden_states) - - return hidden_states - - -def _apply_rotary_emb( - x: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, -) -> torch.FloatTensor: - _, seqlen, _, _ = x.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - x_rot = x[:, :, :, :rotary_dim] - x_pass = x[:, :, :, rotary_dim:] - - x1, x2 = x_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange( - sin[:seqlen], "s d -> s 1 d" - ) - x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] - - x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype) - - return torch.cat([x_rot, x_pass], axis=-1) - - -def _apply_rotary_emb_kv( - kv: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, - cos_k: Optional[torch.FloatTensor] = None, - sin_k: Optional[torch.FloatTensor] = None, -) -> torch.FloatTensor: - _, seqlen, _, _, _ = kv.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - k_rot = kv[:, :, 0, :, :rotary_dim] - k_pass = kv[:, :, 0, :, rotary_dim:] - - k1, k2 = k_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange( - sin[:seqlen], "s d -> s 1 d" - ) - k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]] - - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype) - - return torch.cat( - [ - torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), - kv[:, :, 1:2, :, :], - ], - axis=2, - ) - - -def _apply_rotary_emb_qkv( - qkv: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, - cos_k: Optional[torch.FloatTensor] = None, - sin_k: Optional[torch.FloatTensor] = None, -) -> torch.FloatTensor: - _, seqlen, _, _, _ = qkv.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - - q1, q2 = q_rot.chunk(2, dim=-1) - k1, k2 = k_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange( - sin[:seqlen], "s d -> s 1 d" - ) - q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]] - - q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype) - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype) - - return torch.cat( - [ - torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2), - torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), - qkv[:, :, 2:3, :, :], - ], - axis=2, - ) - - -class RotaryEmbedding(nn.Module): - """Rotary positional embedding (RoPE). - - Reference: - RoFormer: Enhanced Transformer with Rotary Position Embedding. - https://arxiv.org/pdf/2104.09864.pdf. - - """ - - def __init__( - self, - dim: int, - base: int = 10000, - scale_base: Optional[float] = None, - pos_idx_in_fp32: bool = True, - max_position_embeddings: int = 2048, - device: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__() - - if scale_base is not None: - raise NotImplementedError - - self.dim = dim - self.base = float(base) - self.scale_base = scale_base - self.pos_idx_in_fp32 = pos_idx_in_fp32 - self.max_position_embeddings = max_position_embeddings - self.device = device - - # Generate and save the inverse frequency buffer (non-trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Generate and save the scale buffer (non-trainable) - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) - / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - # Initialize cached attributes since ONNX can't rely on dynamic initialization - self._update_cos_sin_cache( - max_position_embeddings, - device=device, - dtype=torch.float32, - ) - - def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor: - return 1.0 / ( - self.base - ** ( - torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) - / self.dim - ) - ) - - def _update_cos_sin_cache( - self, - seqlen: int, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - self._seq_len_cached = seqlen - - # fp32 is preferred since the output of `torch.arange` can be quite large - # and bf16 would lose a lot of precision - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - inv_freq = self.inv_freq - - # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") - - # Force the scale multiplication to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward( - self, - qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: int = 0, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if ( - self._seq_len_cached < qkv.shape[1] + seqlen_offset - or self._cos_cached.device != qkv.device - or self._cos_cached.dtype != qkv.dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._update_cos_sin_cache( - qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype - ) - - if kv is None: - return _apply_rotary_emb_qkv( - qkv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - ) - else: - q = _apply_rotary_emb( - qkv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - ) - kv = _apply_rotary_emb_kv( - kv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - ) - - return q, kv - - -class MLP(nn.Module): - """Multi-Layer Perceptron. - - Reference: - Attention Is All You Need. - https://arxiv.org/pdf/1706.03762.pdf. - - """ - - def __init__( - self, - config: PretrainedConfig, - n_inner: Optional[int] = None, - act_fn: Optional[str] = None, - ) -> None: - super().__init__() - - act_fn = config.activation_function if act_fn is None else act_fn - - n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner - n_inner = n_inner if n_inner is not None else 4 * config.n_embd - - self.fc1 = nn.Linear(config.n_embd, n_inner) - self.fc2 = nn.Linear(n_inner, config.n_embd) - self.act = ACT2FN[act_fn] - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc2(hidden_states) - - return hidden_states - - -class SelfAttention(nn.Module): - """Self-attention layer (compatible with PyTorch). - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. - - """ - - def __init__( - self, - causal: bool = True, - softmax_scale: Optional[float] = None, - attention_dropout: float = 0.0, - ) -> None: - super().__init__() - - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - @torch.autocast("cpu", enabled=False) - @torch.autocast("cuda", enabled=False) - def forward( - self, - qkv: torch.FloatTensor, - causal: bool = None, - key_padding_mask: Optional[torch.BoolTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - q, k, v = qkv.unbind(dim=2) - - q = q.to(torch.float32) - k = k.to(torch.float32) - - causal = self.causal if causal is None else causal - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - - # Autocast is manually disabled to avoid `torch.einsum` performing the operation - # using float16, which might lead to overflow - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - - if key_padding_mask is not None: - padding_mask = torch.full( - (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device - ) - padding_mask.masked_fill_(key_padding_mask, 0.0) - - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - - if causal: - causal_mask = torch.triu( - torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1 - ) - scores = scores + causal_mask.to(dtype=scores.dtype) - - attention = torch.softmax(scores, dim=-1).to(v.dtype) - attention = self.drop(attention) - - output = torch.einsum("bhts,bshd->bthd", attention, v) - - return output - - -class CrossAttention(nn.Module): - """Cross-attention layer (compatible with PyTorch). - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. - - """ - - def __init__( - self, - causal: bool = True, - softmax_scale: Optional[float] = None, - attention_dropout: float = 0.0, - ) -> None: - super().__init__() - - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - @torch.autocast("cpu", enabled=False) - @torch.autocast("cuda", enabled=False) - def forward( - self, - q: torch.FloatTensor, - kv: torch.FloatTensor, - causal: bool = None, - key_padding_mask: Optional[torch.BoolTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = kv.shape[1] - - if kv.shape[3] != q.shape[2]: - kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) - k, v = kv.unbind(dim=2) - - q = q.to(torch.float32) - k = k.to(torch.float32) - - causal = self.causal if causal is None else causal - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - - # Autocast is manually disabled to avoid `torch.einsum` performing the operation - # using float16, which might lead to overflow - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - - if key_padding_mask is not None: - padding_mask = torch.full( - (batch_size, seqlen_k), - -10000.0, - dtype=scores.dtype, - device=scores.device, - ) - padding_mask.masked_fill_(key_padding_mask, 0.0) - - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - - if causal: - rows = rearrange( - torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1" - ) - cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long) - causal_mask = cols > rows + seqlen_k - seqlen_q - - scores = scores.masked_fill(causal_mask, -10000.0) - - attention = torch.softmax(scores, dim=-1).to(v.dtype) - attention = self.drop(attention) - - output = torch.einsum("bhts,bshd->bthd", attention, v) - - return output - - -def _find_mha_dims( - config: PretrainedConfig, - n_head: Optional[int] = None, - n_head_kv: Optional[int] = None, - head_dim: Optional[int] = None, -) -> Tuple[int, int]: - if n_head is None and head_dim is None: - head_dim = config.n_embd // config.n_head - n_head = config.n_head - elif n_head is None or head_dim is None: - raise ValueError("`n_head` and `head_dim` must be both specified or `None`.") - - if n_head_kv is None: - n_head_kv = getattr(config, "n_head_kv", None) or n_head - - return n_head, n_head_kv, head_dim - - -def _update_kv_cache( - kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int -) -> torch.FloatTensor: - num_heads, head_dim = kv.shape[-2:] - - if layer_idx not in inference_params.key_value_memory_dict: - inference_params.key_value_memory_dict[layer_idx] = torch.empty( - inference_params.max_batch_size, - inference_params.max_seqlen, - 2, - num_heads, - head_dim, - dtype=kv.dtype, - device=kv.device, - ) - - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - - sequence_start = inference_params.seqlen_offset - sequence_end = sequence_start + kv.shape[1] - - # When the current sequence length is equal to or larger than the maximum sequence length, - # we need to concatenate the current `kv` with the cached `kv` to expand its length - if sequence_end >= inference_params.max_seqlen: - inference_params.key_value_memory_dict[layer_idx] = torch.concatenate( - (inference_params.key_value_memory_dict[layer_idx], kv), dim=1 - ) - - inference_params.key_value_memory_dict[layer_idx][ - batch_start:batch_end, sequence_start:sequence_end, ... - ] = kv - kv = inference_params.key_value_memory_dict[layer_idx][ - batch_start:batch_end, :sequence_end, ... - ] - - return kv - - -class MHA(nn.Module): - """Multi-head attention layer.""" - - def __init__( - self, - config: PretrainedConfig, - dtype: Optional[torch.dtype] = None, - device: Optional[str] = None, - rotary_dim: Optional[int] = None, - rotary_base: float = 10000.0, - rotary_scale_base: Optional[float] = None, - n_head: Optional[int] = None, - n_head_kv: Optional[int] = None, - head_dim: Optional[int] = None, - bias: bool = True, - causal: bool = True, - softmax_scale: Optional[float] = None, - layer_idx: Optional[int] = None, - return_residual: bool = False, - checkpointing: bool = False, - ) -> None: - super().__init__() - - # Rotary embedding - self.rotary_dim = ( - rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0) - ) - if self.rotary_dim > 0: - rotary_cls = ( - FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding - ) - if rotary_cls is None: - rotary_cls = RotaryEmbedding - - rotary_kwargs = {} - if rotary_cls is RotaryEmbedding: - rotary_kwargs["max_position_embeddings"] = config.n_positions - - self.rotary_emb = rotary_cls( - self.rotary_dim, - base=rotary_base, - scale_base=rotary_scale_base, - device=device, - **rotary_kwargs, - ) - - # MLP - self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims( - config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim - ) - op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv) - hidden_size = config.n_embd - - linear_cls = FusedDense if config.fused_dense else nn.Linear - if linear_cls is None: - linear_cls = nn.Linear - - self.Wqkv = linear_cls( - hidden_size, op_size, bias=bias, device=device, dtype=dtype - ) - self.out_proj = linear_cls( - hidden_size, hidden_size, bias=bias, device=device, dtype=dtype - ) - - # Attention - attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention - if attn_cls is None: - attn_cls = SelfAttention - - cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention - if cross_attn_cls is None: - cross_attn_cls = CrossAttention - - self.inner_attn = attn_cls( - causal=causal, - softmax_scale=softmax_scale, - attention_dropout=config.attn_pdrop, - ) - self.inner_cross_attn = cross_attn_cls( - causal=causal, - softmax_scale=softmax_scale, - attention_dropout=config.attn_pdrop, - ) - - self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention - self.layer_idx = layer_idx - self.return_residual = return_residual - self.checkpointing = checkpointing - self._gradient_checkpointing_func = None - - def _forward_self_attn( - self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor] - ) -> torch.FloatTensor: - qkv = self.Wqkv(x) - qkv = rearrange( - qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim - ) - - if self.rotary_dim > 0: - qkv = self.rotary_emb(qkv) - - if self.flash_attn: - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - - cu_seqlens, max_seqlen = None, None - if key_padding_mask is not None: - # If `key_padding_mask` is supplied, we need to unpad the input and retrieve - # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn` - qkv, indices, cu_seqlens, max_seqlen = unpad_input( - qkv, key_padding_mask - ) - - if self.checkpointing and self.training: - attn_output = self._gradient_checkpointing_func( - self.inner_attn, - qkv, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - use_reentrant=False, - ) - else: - attn_output = self.inner_attn( - qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ).to(qkv.device) - - # If `key_padding_mask` is supplied, we need to pad the output back to the original shape - return ( - pad_input(attn_output, indices, batch_size, seqlen) - if key_padding_mask is not None - else attn_output - ) - - if self.checkpointing and self.training: - return self._gradient_checkpointing_func( - self.inner_attn, - qkv, - key_padding_mask=key_padding_mask, - use_reentrant=False, - ) - - return self.inner_attn(qkv, key_padding_mask=key_padding_mask) - - def _forward_cross_attn( - self, - x: torch.FloatTensor, - past_key_values: Optional[InferenceParams], - key_padding_mask: Optional[torch.BoolTensor], - ) -> torch.FloatTensor: - batch_size = x.shape[0] - - qkv = self.Wqkv(x) - - q = qkv[..., : self.n_head * self.head_dim] - q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) - - kv = qkv[..., self.n_head * self.head_dim :] - kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) - - seqlen_offset = ( - past_key_values.seqlen_offset if past_key_values is not None else 0 - ) - causal = None if seqlen_offset == 0 else False - if self.rotary_dim > 0: - q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset) - - if past_key_values is not None: - kv = _update_kv_cache(kv, past_key_values, self.layer_idx) - - if self.flash_attn: - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = kv.shape[1] - - cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = ( - None, - None, - None, - None, - ) - if key_padding_mask is not None: - kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask) - - if seqlen_q == 1: - key_padding_mask = torch.ones(batch_size, 1, device=q.device) - elif seqlen_q != seqlen_k: - key_padding_mask = key_padding_mask[:, -seqlen_q:] - - q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, key_padding_mask - ) - - if self.checkpointing and self.training: - attn_output = self._gradient_checkpointing_func( - self.inner_cross_attn, - q, - kv, - causal=causal, - cu_seqlens=cu_seqlens_q, - max_seqlen=max_seqlen_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - use_reentrant=False, - ) - else: - attn_output = self.inner_cross_attn( - q, - kv, - causal=causal, - cu_seqlens=cu_seqlens_q, - max_seqlen=max_seqlen_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - ) - - return ( - pad_input(attn_output, indices_q, batch_size, max_seqlen_q) - if key_padding_mask is not None - else attn_output - ) - - if self.checkpointing and self.training: - return self._gradient_checkpointing_func( - self.inner_cross_attn, - q, - kv, - key_padding_mask=key_padding_mask, - causal=causal, - use_reentrant=False, - ) - - return self.inner_cross_attn( - q, kv, key_padding_mask=key_padding_mask, causal=causal - ) - - def forward( - self, - x: torch.FloatTensor, - past_key_values: Optional[InferenceParams] = None, - attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - if attention_mask is not None: - attention_mask = attention_mask.bool() - else: - attention_mask = None - - # MHA - if self.n_head == self.n_head_kv: - if past_key_values is None: - # If `past_key_values` are not supplied, we run self-attention - attn_output = self._forward_self_attn(x, attention_mask) - else: - # If `past_key_values` are supplied, it means that we might have cached values and - # could take advantage of cross-attention - attn_output = self._forward_cross_attn( - x, past_key_values, attention_mask - ) - # MQA / GQA - else: - # Regardless of `past_key_values` being supplied or not, it always use cross-attention - # because `q` and `kv` lengths might be different - attn_output = self._forward_cross_attn(x, past_key_values, attention_mask) - - output = rearrange(attn_output, "... h d -> ... (h d)") - output = self.out_proj(output) - - return output if not self.return_residual else (output, x) - - -class ParallelBlock(nn.Module): - """Parallel block. - - This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen). - - """ - - def __init__( - self, - config: PretrainedConfig, - block_idx: Optional[int] = None, - ) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.block_idx = block_idx - - self.mixer = MHA(config, layer_idx=block_idx) - self.mlp = MLP(config) - self.checkpointing = False - self._gradient_checkpointing_func = None - - def forward( - self, - hidden_states: torch.FloatTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - def _forward( - mixer, - resid_dropout, - mlp, - ln, - hidden_states, - past_key_values, - attention_mask, - ): - residual = hidden_states - hidden_states = ln(hidden_states) - - attn_outputs = mixer( - hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - ) - if isinstance(attn_outputs, tuple): - attn_outputs = attn_outputs[0] - - attn_outputs = resid_dropout(attn_outputs) - feed_forward_hidden_states = resid_dropout(mlp(hidden_states)) - - return attn_outputs + feed_forward_hidden_states + residual - - if self.training and self.checkpointing: - return self._gradient_checkpointing_func( - _forward, - self.mixer, - self.resid_dropout, - self.mlp, - self.ln, - hidden_states, - past_key_values, - attention_mask, - ) - - return _forward( - self.mixer, - self.resid_dropout, - self.mlp, - self.ln, - hidden_states, - past_key_values, - attention_mask, - ) - - -class CausalLMHead(nn.Module): - """Causal Language Modeling head. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. - - """ - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.linear = nn.Linear(config.n_embd, config.vocab_size) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.ln(hidden_states) - logits = self.linear(hidden_states).to(torch.float32) - - return logits - - -class CausalLMLoss(nn.Module): - """Causal Language Modeling loss. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. - - """ - - def __init__(self, shift_labels: bool = True) -> None: - super().__init__() - - self.shift_labels = shift_labels - self.loss_fct = nn.CrossEntropyLoss() - - def forward( - self, logits: torch.FloatTensor, labels: torch.LongTensor - ) -> torch.FloatTensor: - if self.shift_labels: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - - loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - - return loss - - -class PhiPreTrainedModel(PreTrainedModel): - """Phi pre-trained model.""" - - config_class = PhiConfig - base_model_prefix = "transformer" - supports_gradient_checkpointing = True - _no_split_modules = ["ParallelBlock"] - - def __init__(self, *inputs, **kwargs) -> None: - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module: nn.Module) -> None: - if isinstance(module, (nn.Linear,)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if module.bias is not None: - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def _set_gradient_checkpointing( - self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint - ): - for module in self.modules(): - if hasattr(module, "checkpointing"): - module._gradient_checkpointing_func = gradient_checkpointing_func - module.checkpointing = enable - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None, - **kwargs, - ) -> Dict[str, Any]: - if past_key_values is None or not ( - isinstance(past_key_values, InferenceParams) - ): - past_key_values = InferenceParams( - max_seqlen=self.config.n_positions, - max_batch_size=input_ids.shape[0], - seqlen_offset=0, - batch_size_offset=0, - key_value_memory_dict={}, - lengths_per_sample=None, - ) - else: - # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids` - past_key_values.seqlen_offset = input_ids.shape[1] - 1 - input_ids = input_ids[:, -1].unsqueeze(-1) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - } - - -class PhiModel(PhiPreTrainedModel): - """Phi model.""" - - _keys_to_ignore_on_load_missing = [""] - _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] - - def __init__(self, config: PhiConfig) -> None: - super().__init__(config) - - self.embd = Embedding(config) - self.h = nn.ModuleList( - [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)] - ) - self.gradient_checkpointing = False - self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - return self.embd.wte - - def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: - self.embd.wte = new_embeddings - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - ) -> torch.FloatTensor: - hidden_states = self.embd(input_ids) - - for layer in self.h: - hidden_states = layer( - hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - ) - - return hidden_states - - -class PhiForCausalLM(PhiPreTrainedModel): - """Phi for Causal Language Modeling.""" - - _keys_to_ignore_on_load_missing = [""] - _keys_to_ignore_on_load_unexpected = [ - r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)" - ] - - def __init__(self, config: PhiConfig) -> None: - super().__init__(config) - - self.transformer = PhiModel(config) - self.lm_head = CausalLMHead(config) - self.loss = CausalLMLoss() - - self.post_init() - - def get_output_embeddings(self) -> nn.Linear: - return self.lm_head.linear - - def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: - self.lm_head.linear = new_embeddings - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - labels: Optional[torch.LongTensor] = None, - **kwargs, - ) -> CausalLMOutputWithPast: - hidden_states = self.transformer( - input_ids, past_key_values=past_key_values, attention_mask=attention_mask - ) - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - loss = self.loss(lm_logits, labels) - - return CausalLMOutputWithPast( - loss=loss, logits=lm_logits, past_key_values=past_key_values - ) diff --git a/src/axolotl/monkeypatch/phi/__init__.py b/src/axolotl/monkeypatch/phi/__init__.py new file mode 100644 index 0000000000..1076708a0b --- /dev/null +++ b/src/axolotl/monkeypatch/phi/__init__.py @@ -0,0 +1,12 @@ +""" +Patches to support multipack for phi2 +""" +import transformers + +from axolotl.monkeypatch.utils import get_unpad_data + + +def replace_phi_attn_with_multipack_flash_attn(): + transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 78fbe52d29..09bc31db3b 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -364,20 +364,6 @@ def validate_config(cfg): "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." ) - if cfg.model_type == "MixFormerSequentialForCausalLM" and cfg.adapter is not None: - LOG.warning("Use AutoModelForCausalLM for phi/MixFormer models with qLoRA") - - if cfg.model_config_type == "mixformer-sequential": - if cfg.sample_packing: - if cfg.adapter is not None: - LOG.warning( - "phi/MixFormer models are not currently compatible with LoRA and sample_packing" - ) - if cfg.model_type == "AutoModelForCausalLM": - raise ValueError( - "`model_type: MixFormerSequentialForCausalLM` required for sample_packing" - ) - if cfg.datasets: for idx, ds_cfg in enumerate(cfg.datasets): if not ds_cfg.type: diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index fb2eb9bc42..5839f74f69 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -397,7 +397,7 @@ def for_d_in_datasets(dataset_configs): LOG.info("shuffle merged datasets") dataset = dataset.shuffle(seed=seed) - dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer) + dataset, _ = process_datasets_for_packing(cfg, dataset, None) if cfg.local_rank == 0: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") diff --git a/src/axolotl/utils/lora_embeddings.py b/src/axolotl/utils/lora_embeddings.py index d9fe35eb81..70f56655ea 100644 --- a/src/axolotl/utils/lora_embeddings.py +++ b/src/axolotl/utils/lora_embeddings.py @@ -7,8 +7,6 @@ def get_linear_embedding_layers(model_type): """ returns the linear embedding layers needed for loras, dependent on the model arch """ - if model_type == "phi-msft": - return ["embd.wte", "lm_head.linear"] if model_type == "gpt_neox": return ["embed_in", "embed_out"] if model_type == "falcon": diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6ba1e3704b..25b575686c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -169,6 +169,7 @@ def load_tokenizer(cfg): # pylint: disable=too-many-boolean-expressions if ( (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) + and (len(tokenizer.encode(val)) > 1) and cfg.adapter and ( not cfg.lora_modules_to_save @@ -342,6 +343,12 @@ def load_model( LOG.info("patching falcon with flash attention") replace_falcon_attn_with_multipack_flash_attn() + if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing: + from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn + + LOG.info("patching phi with flash attention") + replace_phi_attn_with_multipack_flash_attn() + if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing: from axolotl.monkeypatch.qwen2 import ( replace_qwen2_attn_with_multipack_flash_attn, @@ -448,7 +455,7 @@ def load_model( "flash_attention_2" ) else: - if model_config.model_type in ["mixtral", "qwen2", "falcon"]: + if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]: model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" @@ -458,10 +465,6 @@ def load_model( model_config._attn_implementation = ( # pylint: disable=protected-access "eager" ) - if model_config.model_type == "phi-msft": - model_config.flash_attn = True - model_config.flash_rotary = True - model_config.fused_dense = True try: if ( @@ -518,16 +521,6 @@ def load_model( # device=cfg.device, # ) # model.train() # sets to train instead of eval mode - elif model_type == "PhiForCausalLM" or model_config.model_type == "phi-msft": - from axolotl.models.phi import PhiForCausalLM - - model = PhiForCausalLM.from_pretrained( - base_model, - config=model_config, - load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, - load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - **model_kwargs, - ) elif model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 2e9d782c74..38b67fb434 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -106,19 +106,16 @@ def drop_long_seq(sample, sequence_len=2048): return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 -def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): +def process_datasets_for_packing(cfg, train_dataset, eval_dataset): drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) with zero_first(is_main_process()): if cfg.is_preprocess: max_input_len = np.max(get_dataset_lengths(train_dataset)) LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) - # Phi doesn't want the attention_mask feature when training if ( - "CodeGenTokenizer" in tokenizer.__class__.__name__ - or (cfg.is_mistral_derived_model and cfg.flash_attention) - or cfg.model_config_type == "mamba" - ): + cfg.is_mistral_derived_model and cfg.flash_attention + ) or cfg.model_config_type == "mamba": LOG.info("dropping attention_mask column") train_dataset = train_dataset.remove_columns("attention_mask") if eval_dataset: diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py new file mode 100644 index 0000000000..5f30453c18 --- /dev/null +++ b/tests/e2e/patched/test_phi_multipack.py @@ -0,0 +1,123 @@ +""" +E2E tests for lora llama +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestPhiMultipack(unittest.TestCase): + """ + Test case for Phi2 models + """ + + @with_temp_dir + def test_ft_packed(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "microsoft/phi-1_5", + "model_type": "PhiForCausalLM", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "sample_packing": True, + "flash_attention": True, + "pad_to_sequence_len": True, + "load_in_8bit": False, + "adapter": None, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "dataset_shard_num": 10, + "dataset_shard_idx": 0, + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "eval_steps": 10, + "save_steps": 10, + "bf16": "auto", + } + ) + + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "pytorch_model.bin").exists() + + @with_temp_dir + def test_qlora_packed(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "microsoft/phi-1_5", + "model_type": "PhiForCausalLM", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "sample_packing": True, + "flash_attention": True, + "pad_to_sequence_len": True, + "load_in_8bit": False, + "adapter": "qlora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "dataset_shard_num": 10, + "dataset_shard_idx": 0, + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "eval_steps": 10, + "save_steps": 10, + "bf16": "auto", + } + ) + + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index 80c748cc9c..4cc6bcdcc9 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -7,9 +7,6 @@ import unittest from pathlib import Path -import pytest -from transformers.utils import is_torch_bf16_gpu_available - from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs from axolotl.train import train @@ -27,17 +24,15 @@ class TestPhi(unittest.TestCase): Test case for Phi2 models """ - @pytest.mark.skip(reason="fixme later") @with_temp_dir - def test_phi2_ft(self, temp_dir): + def test_phi_ft(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( { - "base_model": "microsoft/phi-2", - "trust_remote_code": True, + "base_model": "microsoft/phi-1_5", "model_type": "AutoModelForCausalLM", "tokenizer_type": "AutoTokenizer", - "sequence_len": 512, + "sequence_len": 2048, "sample_packing": False, "load_in_8bit": False, "adapter": None, @@ -64,13 +59,9 @@ def test_phi2_ft(self, temp_dir): "max_steps": 10, "save_steps": 10, "eval_steps": 10, - "save_safetensors": True, + "bf16": "auto", } ) - if is_torch_bf16_gpu_available(): - cfg.bf16 = True - else: - cfg.fp16 = True normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -78,25 +69,24 @@ def test_phi2_ft(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "pytorch_model.bin").exists() - @pytest.mark.skip(reason="multipack no longer supported atm") @with_temp_dir - def test_ft_packed(self, temp_dir): + def test_phi_qlora(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( { - "base_model": "microsoft/phi-2", - "trust_remote_code": True, - "model_type": "PhiForCausalLM", + "base_model": "microsoft/phi-1_5", + "model_type": "AutoModelForCausalLM", "tokenizer_type": "AutoTokenizer", - "sequence_len": 512, - "sample_packing": True, + "sequence_len": 2048, + "sample_packing": False, "load_in_8bit": False, - "adapter": None, + "adapter": "qlora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.05, + "lora_target_linear": True, "val_set_size": 0.1, "special_tokens": { - "unk_token": "<|endoftext|>", - "bos_token": "<|endoftext|>", - "eos_token": "<|endoftext|>", "pad_token": "<|endoftext|>", }, "datasets": [ @@ -112,18 +102,18 @@ def test_ft_packed(self, temp_dir): "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, - "optimizer": "adamw_bnb_8bit", + "optimizer": "paged_adamw_8bit", "lr_scheduler": "cosine", + "flash_attention": True, + "max_steps": 10, + "save_steps": 10, + "eval_steps": 10, + "bf16": "auto", } ) - if is_torch_bf16_gpu_available(): - cfg.bf16 = True - else: - cfg.fp16 = True - normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() + assert (Path(temp_dir) / "adapter_model.bin").exists() diff --git a/tests/test_validation.py b/tests/test_validation.py index 5201bdf46a..5c3641f65e 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -742,11 +742,11 @@ def test_llama_add_tokens_adapter(self): check_model_config(cfg, model_config) - def test_phi2_add_tokens_adapter(self): + def test_phi_add_tokens_adapter(self): cfg = DictDefault( {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]} ) - model_config = DictDefault({"model_type": "phi-msft"}) + model_config = DictDefault({"model_type": "phi"}) with pytest.raises( ValueError, @@ -759,7 +759,7 @@ def test_phi2_add_tokens_adapter(self): "adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"], - "lora_modules_to_save": ["embed_tokens", "lm_head"], + "lora_modules_to_save": ["embd.wte", "lm_head.linear"], } ) @@ -774,7 +774,7 @@ def test_phi2_add_tokens_adapter(self): "adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"], - "lora_modules_to_save": ["embd.wte", "lm_head.linear"], + "lora_modules_to_save": ["embed_tokens", "lm_head"], } )