diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 18dc353a23..334948dd9d 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -12,7 +12,7 @@
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Type, Union
 
 import torch
 import transformers
@@ -37,6 +37,7 @@
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
     MambaDataCollator,
+    V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
@@ -896,14 +897,22 @@ def build_collator(
         if is_eval and training_args.eval_sample_packing:
             use_batch_sampler_collator = True
 
+        collator: Type[
+            Union[
+                V2BatchSamplerDataCollatorForSeq2Seq,
+                BatchSamplerDataCollatorForSeq2Seq,
+                DataCollatorForSeq2Seq,
+            ]
+        ]
         if use_batch_sampler_collator:
-            return BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **kwargs,
-            )
+            if self.cfg.model_config_type == "mixtral":
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            else:
+                collator = BatchSamplerDataCollatorForSeq2Seq
+        else:
+            collator = DataCollatorForSeq2Seq
 
-        return DataCollatorForSeq2Seq(
+        return collator(
             self.tokenizer,
             return_tensors="pt",
             **kwargs,
diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py
index 74fa00f649..fb40172308 100644
--- a/src/axolotl/monkeypatch/mixtral/__init__.py
+++ b/src/axolotl/monkeypatch/mixtral/__init__.py
@@ -3,20 +3,10 @@
 """
 import transformers
 
+from axolotl.monkeypatch.utils import get_unpad_data
 
-def replace_mixtral_attn_with_multipack_flash_attn():
-    from .modeling_mixtral import (
-        MixtralMultipackFlashAttention2,
-        mixtral_decoder_layer_forward,
-        mixtral_model_forward,
-    )
 
-    transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
-        mixtral_decoder_layer_forward
-    )
-    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
-        mixtral_model_forward
+def replace_mixtral_attn_with_multipack_flash_attn():
+    transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
     )
-    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
-        "flash_attention_2"
-    ] = MixtralMultipackFlashAttention2
diff --git a/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py b/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
deleted file mode 100644
index db892530d6..0000000000
--- a/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
+++ /dev/null
@@ -1,383 +0,0 @@
-"""
-Mixtral modeling for multipack
-"""
-# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
-import logging
-import warnings
-from typing import List, Optional, Tuple, Union
-
-import torch
-from einops import rearrange
-from flash_attn import flash_attn_varlen_qkvpacked_func
-from transformers import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import MoeModelOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import (
-    MixtralFlashAttention2,
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
-
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
-
-LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
-
-
-class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
-    """
-    Custom multipack implementation w flash attention 2
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._flash_attn_uses_top_left_mask = True
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids
-        )
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
-            # special handling using sample packing
-            qkv = torch.stack(
-                [query_states, key_states, value_states], dim=2
-            )  # [bsz, nh, 3, q_len, hd]
-            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-            qkv = rearrange(qkv, "b s ... -> (b s) ...")
-
-            attn_output = flash_attn_varlen_qkvpacked_func(
-                qkv,
-                cu_seqlens,
-                max_seqlen,
-                dropout_p=self.attention_dropout,
-                softmax_scale=None,
-                causal=True,
-            )
-            attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-def mixtral_decoder_layer_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: Optional[bool] = False,
-    output_router_logits: Optional[bool] = False,
-    use_cache: Optional[bool] = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[torch.Tensor] = None,
-    **kwargs,
-) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-    if "padding_mask" in kwargs:
-        warnings.warn(
-            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-        )
-    """
-    Args:
-        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-            `(batch, sequence_length)` where padding elements are indicated by 0.
-        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-            returned tensors for more detail.
-        output_router_logits (`bool`, *optional*):
-            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
-            should not be returned during inference.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-            (see `past_key_values`).
-    """
-
-    residual = hidden_states
-
-    hidden_states = self.input_layernorm(hidden_states)
-
-    # Self Attention
-    hidden_states, self_attn_weights, present_key_value = self.self_attn(
-        hidden_states=hidden_states,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_value=past_key_value,
-        output_attentions=output_attentions,
-        use_cache=use_cache,
-        cu_seqlens=cu_seqlens,
-        max_seqlen=max_seqlen,
-    )
-    hidden_states = residual + hidden_states
-
-    # Fully Connected
-    residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-    hidden_states = residual + hidden_states
-
-    outputs = (hidden_states,)
-
-    if output_attentions:
-        outputs += (self_attn_weights,)
-
-    if use_cache:
-        outputs += (present_key_value,)
-
-    if output_router_logits:
-        outputs += (router_logits,)
-
-    return outputs
-
-
-def mixtral_model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    output_router_logits: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-) -> Union[Tuple, MoeModelOutputWithPast]:
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-
-    # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError(
-            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-        )
-    if input_ids is not None:
-        batch_size, seq_length = input_ids.shape
-    elif inputs_embeds is not None:
-        batch_size, seq_length, _ = inputs_embeds.shape
-    else:
-        raise ValueError(
-            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
-        )
-
-    past_key_values_length = 0
-
-    if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-    cu_seqlens = None
-    max_seqlen = None
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length,
-            seq_length + past_key_values_length,
-            dtype=torch.long,
-            device=device,
-        )
-        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-    else:
-        position_ids = position_ids.view(-1, seq_length).long()
-        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
-        cu_seqlens = cu_seqlens.squeeze()
-
-    if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
-
-    if (
-        attention_mask is not None
-        and self._attn_implementation == "flash_attention_2"
-        and use_cache
-    ):
-        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-        if is_padding_right:
-            raise ValueError(
-                "You are attempting to perform batched generation with padding_side='right'"
-                " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
-                " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-            )
-
-    if self._attn_implementation == "flash_attention_2":
-        # 2d mask is passed through the layers
-        attention_mask = (
-            attention_mask
-            if (attention_mask is not None and 0 in attention_mask)
-            else None
-        )
-    else:
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-            sliding_window=self.config.sliding_window,
-        )
-
-    hidden_states = inputs_embeds
-
-    if self.gradient_checkpointing and self.training:
-        if use_cache:
-            LOG.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
-            use_cache = False
-
-    # decoder layers
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attns = () if output_attentions else None
-    all_router_logits = () if output_router_logits else None
-    next_decoder_cache = None
-
-    for decoder_layer in self.layers:
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_values,
-                output_attentions,
-                output_router_logits,
-                use_cache,
-                cu_seqlens,
-                max_seqlen,
-            )
-        else:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                output_router_logits=output_router_logits,
-                use_cache=use_cache,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
-            )
-
-        hidden_states = layer_outputs[0]
-
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-        if output_router_logits:
-            all_router_logits += (layer_outputs[-1],)
-
-    hidden_states = self.norm(hidden_states)
-
-    # add hidden states from the last decoder layer
-    if output_hidden_states:
-        all_hidden_states += (hidden_states,)
-
-    next_cache = None
-    if use_cache:
-        next_cache = (
-            next_decoder_cache.to_legacy_cache()
-            if use_legacy_cache
-            else next_decoder_cache
-        )
-
-    if not return_dict:
-        return tuple(
-            v
-            for v in [
-                hidden_states,
-                next_cache,
-                all_hidden_states,
-                all_self_attns,
-                all_router_logits,
-            ]
-            if v is not None
-        )
-
-    return MoeModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
-        router_logits=all_router_logits,
-    )
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index d71fbc6bc5..193c9d0ec6 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -2,6 +2,40 @@
 Shared utils for the monkeypatches
 """
 import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
+    max_num = int(torch.max(attention_mask).item())
+    batch_size, _ = attention_mask.shape
+    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
+
+    for i in range(1, max_num + 1):
+        mask = attention_mask == i
+        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
+
+    result = counts.flatten()
+    nonzero_indices = torch.nonzero(result).squeeze(-1)
+    return result[nonzero_indices]
+
+
+@torch.jit.script
+def get_unpad_data(attention_mask: torch.Tensor):
+    device = attention_mask.device
+    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
+    indices = torch.nonzero(attention_mask.flatten()).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = (
+        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+        .to(device=device)
+        .detach()
+    )
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
 
 
 def get_cu_seqlens(attn_mask):
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index b9c1c3b3c1..da1ee392fb 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -152,6 +152,33 @@ def __call__(self, features, return_tensors=None):
         return super().__call__(features, return_tensors=return_tensors)
 
 
+@dataclass
+class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack specific to the using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features[0].keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [
+                    (i + 1) * np.array(item[feature])
+                    for i, item in enumerate(features)
+                    if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [
+                    np.array(item[feature]) for item in features if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)
+
+
 @dataclass
 class MambaDataCollator:
     """
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 0df6136c6f..beade8621b 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -1,12 +1,14 @@
 """Module for working with config dicts"""
-
+import json
 import logging
 import os
+from pathlib import Path
 
 import torch
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model_config
 
 LOG = logging.getLogger("axolotl")
@@ -135,7 +137,7 @@ def normalize_config(cfg):
             ]
         )
         or cfg.is_mistral_derived_model
-        or "mistral" in cfg.base_model.lower()
+        or "mistral" in cfg.base_model.lower().split("/")[-1]
         or (cfg.model_type and "mistral" in cfg.model_type.lower())
     )
 
@@ -484,6 +486,40 @@ def validate_config(cfg):
             "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
         )
 
+    if (
+        cfg.unfrozen_parameters
+        and cfg.gradient_checkpointing_kwargs
+        and cfg.gradient_checkpointing_kwargs.use_reentrant is True
+    ):
+        # https://github.com/huggingface/transformers/issues/21381
+        raise ValueError(
+            "`use_reentrant` must be false when used with partially frozen model."
+        )
+
+    if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file():
+        with open(cfg.deepspeed, encoding="utf-8") as file:
+            contents = file.read()
+            deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
+            if (
+                deepspeed_cfg.zero_optimization
+                and deepspeed_cfg.zero_optimization.stage == 3
+            ):
+                if not (
+                    (
+                        deepspeed_cfg.bf16
+                        and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
+                        is True
+                    )
+                    or (
+                        deepspeed_cfg.fp16
+                        and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
+                        is True
+                    )
+                ):
+                    raise ValueError(
+                        "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
+                    )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 71cdcc69c5..0342499324 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -305,12 +305,16 @@ def load_model(
         )
 
     # Modify mistral derived models
-    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
+    if (
+        cfg.model_config_type == "mistral"
+        and cfg.flash_attention
+        and cfg.sample_packing
+    ):
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
         )
 
-        LOG.info("patching with flash attention")
+        LOG.info("patching mistral with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
     if (
@@ -322,7 +326,7 @@ def load_model(
             replace_mixtral_attn_with_multipack_flash_attn,
         )
 
-        LOG.info("patching with flash attention")
+        LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()
 
     if (
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 3fc2446052..44642fb300 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -152,6 +152,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         or (cfg.is_mistral_derived_model and cfg.flash_attention)
         or cfg.model_config_type == "mamba"
     ):
+        LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 4eff3825ae..30c53103ed 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -7,8 +7,6 @@
 import unittest
 from pathlib import Path
 
-from transformers.utils import is_torch_bf16_gpu_available
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
@@ -60,12 +58,9 @@ def test_qlora(self, temp_dir):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "sample_packing": True,
+                "bf16": "auto",
             }
         )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -101,23 +96,16 @@ def test_ft(self, temp_dir):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "sample_packing": True,
+                "bf16": "auto",
             }
         )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (
-            "axolotl.monkeypatch.mixtral.modeling_mixtral"
-            in model.model.layers[0].self_attn.__class__.__module__
-        )
-        assert (
-            "MixtralMultipackFlashAttention2"
+            "MixtralFlashAttention2"
             in model.model.layers[0].self_attn.__class__.__name__
         )
         assert (Path(temp_dir) / "pytorch_model.bin").exists()
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index 65d372c735..8384b826f1 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -52,11 +52,7 @@ def test_mixtral_multipack(self, temp_dir):
         model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
 
         assert (
-            "axolotl.monkeypatch.mixtral.modeling_mixtral"
-            in model.model.layers[0].self_attn.__class__.__module__
-        )
-        assert (
-            "MixtralMultipackFlashAttention2"
+            "MixtralFlashAttention2"
             in model.model.layers[0].self_attn.__class__.__name__
         )
 
diff --git a/tests/monkeypatch/test_llama_attn_hijack_flash.py b/tests/monkeypatch/test_llama_attn_hijack_flash.py
index 289c01a863..cce421e88a 100644
--- a/tests/monkeypatch/test_llama_attn_hijack_flash.py
+++ b/tests/monkeypatch/test_llama_attn_hijack_flash.py
@@ -5,7 +5,12 @@
 
 import torch
 
-from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
+from axolotl.monkeypatch.utils import (
+    get_cu_seqlens,
+    get_cu_seqlens_from_pos_ids,
+    get_max_seqlen_in_batch,
+    get_unpad_data,
+)
 
 
 class TestMonkeyPatchUtils(unittest.TestCase):
@@ -25,6 +30,70 @@ def test_get_cu_seqlens_from_pos_ids_1d(self):
             torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
         )
 
+    def test_get_max_seqlen_in_batch(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
+        self.assertTrue(torch.allclose(get_max_seqlen_in_batch(attn_mask), target_res))
+
+    def test_get_unpad_data(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
+        target_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
+        target_max_seqlen_in_batch = 5
+        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
+        self.assertTrue(torch.allclose(target_indices, indices))
+        self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
+        self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
+
+        attn_mask = torch.tensor(
+            [
+                [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
+                [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
+            ]
+        )
+        target_indices = torch.tensor(
+            [
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                11,
+                12,
+                13,
+                16,
+                17,
+                18,
+                19,
+                20,
+                21,
+                22,
+                23,
+                24,
+                25,
+                26,
+                27,
+                28,
+                29,
+                30,
+                31,
+            ]
+        )
+        target_cu_seqlen = torch.tensor(
+            [0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
+        )
+        target_max_seqlen_in_batch = 5
+        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
+        self.assertTrue(torch.allclose(target_indices, indices))
+        self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
+        self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
+
 
 if __name__ == "__main__":
     unittest.main()