diff --git a/src/axolotl/models/phi/modeling_mixformer_sequential.py b/src/axolotl/models/phi/modeling_mixformer_sequential.py
index 860d2f7694..46bb526ff1 100644
--- a/src/axolotl/models/phi/modeling_mixformer_sequential.py
+++ b/src/axolotl/models/phi/modeling_mixformer_sequential.py
@@ -36,6 +36,7 @@ from __future__ import annotations
 
 import copy
+import inspect
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional, Tuple
 
@@ -284,10 +285,10 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     sequence_start = inference_params.sequence_len_offset
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (
-        kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]
+        kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]  # noqa
     )
     assert sequence_end <= (
-        kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]
+        kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]  # noqa
     )
 
     assert kv_cache is not None
@@ -394,8 +395,8 @@ def __init__(
         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
 
-        gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
-        activation = "gelu_approx" if act_fn in gelu_activations else "relu"
+        gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]  # noqa
+        activation = "gelu_approx" if act_fn in gelu_activations else "relu"  # noqa
 
         self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
 
@@ -422,7 +423,7 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         self.drop = nn.Dropout(attention_dropout)
 
     def forward(
-        self, qkv, causal=None, key_padding_mask=None, cu_seqlen=None, max_seqlen=None
+        self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
     ):
         """Implements the multihead softmax attention.
         Arguments
@@ -433,9 +434,11 @@ def forward(
                 False means to mask out. (B, S)
         """
         causal = self.causal if causal is None else causal
-        if cu_seqlen:
+        if cu_seqlens is not None:
             return flash_attn_varlen_qkvpacked_func(
-                qkv,
+                qkv.squeeze(0),
+                cu_seqlens,
+                max_seqlen,
                 dropout_p=self.drop.p,
                 softmax_scale=self.softmax_scale,
                 causal=causal,
@@ -645,7 +648,7 @@ def forward(
             assert max_seqlen is not None
             assert key_padding_mask is None
             assert self.flash_attn
-            assert self.rotary_emb_dim == 0
+            # assert self.rotary_emb_dim == 0
 
         if key_padding_mask is not None:
             assert cu_seqlens is None
@@ -826,6 +829,22 @@ def prepare_inputs_for_generation(
 
         return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
 
 
+class PackedSequential(nn.Sequential):
+    def forward(
+        self,
+        input,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        max_seqlen: Optional[int] = None,
+    ):
+        for module in self:
+            sig = inspect.signature(module.forward)
+            if "cu_seqlens" in sig.parameters:
+                input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+            else:
+                input = module(input)
+        return input
+
+
 class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
     """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
 
@@ -851,14 +870,14 @@ def __init__(self, config: MixFormerSequentialConfig) -> None:
                 # `block_cls` with `legacy` value is for backward compatibility
                 # `path` key is for backward compatibility
                 block = copy.deepcopy(block) or {"block_cls": "parallel"}
-                block_cls = block.pop("path", None) or block.pop("block_cls", None)
+                # block_cls = block.pop("path", None) or block.pop("block_cls", None)
 
                 block["block_idx"] = block_idx
                 modules.append(ParallelBlock(config, **block))
 
         modules.append(CausalLMHead(config))
 
-        self.layers = nn.Sequential(*modules)
+        self.layers = PackedSequential(*modules)
         self.loss = CausalLMLoss()
 
         self.post_init()
@@ -885,7 +904,7 @@ def forward(
     ) -> CausalLMOutputWithPast:
         cu_seqlens: Optional[torch.LongTensor] = None
         max_seqlen: Optional[int] = None
-        if position_ids:
+        if position_ids is not None:
             batch_size, seq_length = input_ids.shape
             position_ids = position_ids.view(-1, seq_length).long()
             cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
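
Note on the varlen change above: flash_attn_varlen_qkvpacked_func takes an unpadded qkv of shape (total_tokens, 3, n_heads, head_dim) -- hence the qkv.squeeze(0) -- plus int32 cumulative sequence lengths and the length of the longest packed sequence. Below is a minimal sketch of what get_cu_seqlens_from_pos_ids is assumed to produce from packed position_ids that restart at zero for each sample; the name cu_seqlens_from_pos_ids_sketch is illustrative, not the actual axolotl helper, which may differ in dtype or edge-case handling.

    # Sketch: deriving cu_seqlens/max_seqlen from packed position_ids.
    # Assumption: position_ids reset to 0 at every packed-sequence boundary.
    import torch

    def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
        pos = position_ids.view(-1)
        # Each zero marks the start of a new packed sequence.
        starts = torch.nonzero(pos == 0, as_tuple=False).view(-1)
        ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
        seqlens = ends - starts
        # flash_attn_varlen_* expects int32 cumulative lengths with a leading 0.
        cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
        cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
        return cu_seqlens, int(seqlens.max())

    # Three sequences of lengths 3, 2 and 4 packed into one row:
    cu, max_len = cu_seqlens_from_pos_ids_sketch(
        torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
    )
    # cu == tensor([0, 3, 5, 9], dtype=torch.int32), max_len == 4

With these boundaries, flash attention keeps each packed sample attending only within its own sequence, and PackedSequential simply threads cu_seqlens/max_seqlen to the layers whose forward signatures accept them.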