sample packing fixes
winglian committed Sep 15, 2023
1 parent a159dca commit cf773f1
Showing 1 changed file with 30 additions and 11 deletions.
src/axolotl/models/phi/modeling_mixformer_sequential.py
@@ -36,6 +36,7 @@
 from __future__ import annotations

 import copy
+import inspect
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional, Tuple
@@ -284,10 +285,10 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     sequence_start = inference_params.sequence_len_offset
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (
-        kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]
+        kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]  # noqa
     )
     assert sequence_end <= (
-        kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]
+        kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]  # noqa
     )

     assert kv_cache is not None
@@ -394,8 +395,8 @@ def __init__(
         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
         n_inner = n_inner if n_inner is not None else 4 * config.n_embd

-        gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
-        activation = "gelu_approx" if act_fn in gelu_activations else "relu"
+        gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]  # noqa
+        activation = "gelu_approx" if act_fn in gelu_activations else "relu"  # noqa

         self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)

@@ -422,7 +423,7 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         self.drop = nn.Dropout(attention_dropout)

     def forward(
-        self, qkv, causal=None, key_padding_mask=None, cu_seqlen=None, max_seqlen=None
+        self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
     ):
         """Implements the multihead softmax attention.
         Arguments
@@ -433,9 +434,11 @@ def forward(
                 False means to mask out. (B, S)
         """
         causal = self.causal if causal is None else causal
-        if cu_seqlen:
+        if cu_seqlens is not None:
             return flash_attn_varlen_qkvpacked_func(
-                qkv,
+                qkv.squeeze(0),
+                cu_seqlens,
+                max_seqlen,
                 dropout_p=self.drop.p,
                 softmax_scale=self.softmax_scale,
                 causal=causal,
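
Why the squeeze: with sample packing, a "batch" is one row of several concatenated samples, and flash_attn_varlen_qkvpacked_func takes unpadded tokens of shape (total_tokens, 3, n_heads, head_dim) plus int32 cu_seqlens offsets instead of a padding mask. A minimal sketch of the calling convention (assumes flash-attn 2.x and a CUDA device; the sizes are illustrative):

import torch
from flash_attn import flash_attn_varlen_qkvpacked_func

n_heads, head_dim = 4, 32
# two samples of lengths 5 and 3 packed into a single row of 8 tokens
seqlens = torch.tensor([5, 3])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32).cuda()  # [0, 5, 8]
max_seqlen = int(seqlens.max())

qkv = torch.randn(1, 8, 3, n_heads, head_dim, dtype=torch.float16, device="cuda")
out = flash_attn_varlen_qkvpacked_func(
    qkv.squeeze(0),  # drop the packed batch dim: (total_tokens, 3, n_heads, head_dim)
    cu_seqlens,      # prefix sums of sample lengths, shape (n_samples + 1,)
    max_seqlen,      # length of the longest packed sample
    dropout_p=0.0,
    causal=True,     # causal within each sample; cu_seqlens keeps samples from attending across boundaries
)
print(out.shape)  # torch.Size([8, 4, 32])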
@@ -645,7 +648,7 @@ def forward(
             assert max_seqlen is not None
             assert key_padding_mask is None
             assert self.flash_attn
-            assert self.rotary_emb_dim == 0
+            # assert self.rotary_emb_dim == 0

         if key_padding_mask is not None:
             assert cu_seqlens is None
@@ -826,6 +829,22 @@ def prepare_inputs_for_generation(
         return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}


+class PackedSequential(nn.Sequential):
+    def forward(
+        self,
+        input,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        max_seqlen: Optional[int] = None,
+    ):
+        for module in self:
+            sig = inspect.signature(module.forward)
+            if "cu_seqlens" in sig.parameters:
+                input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+            else:
+                input = module(input)
+        return input
+
+
 class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
     """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""

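nn.Sequential.forward threads a single positional input through the stack, so it cannot carry the packing metadata; PackedSequential inspects each submodule's forward signature and passes cu_seqlens/max_seqlen only to modules that declare them (the ParallelBlocks), while Embedding and CausalLMHead are called as before. A toy sketch of that dispatch rule with stand-in modules (the module names below are illustrative, not from this file):

import inspect

import torch
import torch.nn as nn

class Block(nn.Module):  # stands in for ParallelBlock: accepts packing metadata
    def forward(self, x, cu_seqlens=None, max_seqlen=None):
        return x

class Head(nn.Module):  # stands in for Embedding/CausalLMHead: plain input only
    def forward(self, x):
        return x

def packed_call(modules, x, cu_seqlens, max_seqlen):
    # same routing rule as PackedSequential.forward above
    for module in modules:
        if "cu_seqlens" in inspect.signature(module.forward).parameters:
            x = module(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        else:
            x = module(x)
    return x

out = packed_call(
    [Head(), Block(), Head()],
    torch.zeros(1, 8, 16),
    cu_seqlens=torch.tensor([0, 5, 8], dtype=torch.int32),
    max_seqlen=5,
)
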
@@ -851,14 +870,14 @@ def __init__(self, config: MixFormerSequentialConfig) -> None:
             # `block_cls` with `legacy` value is for backward compatibility
             # `path` key is for backward compatibility
             block = copy.deepcopy(block) or {"block_cls": "parallel"}
-            block_cls = block.pop("path", None) or block.pop("block_cls", None)
+            # block_cls = block.pop("path", None) or block.pop("block_cls", None)

             block["block_idx"] = block_idx
             modules.append(ParallelBlock(config, **block))

         modules.append(CausalLMHead(config))

-        self.layers = nn.Sequential(*modules)
+        self.layers = PackedSequential(*modules)
         self.loss = CausalLMLoss()

         self.post_init()
@@ -885,7 +904,7 @@ def forward(
     ) -> CausalLMOutputWithPast:
         cu_seqlens: Optional[torch.LongTensor] = None
         max_seqlen: Optional[int] = None
-        if position_ids:
+        if position_ids is not None:
             batch_size, seq_length = input_ids.shape
             position_ids = position_ids.view(-1, seq_length).long()
             cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
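The packing contract that makes this work: position_ids restart at 0 at every sample boundary (e.g. [0, 1, 2, 3, 4, 0, 1, 2] for packed lengths 5 and 3), so sequence boundaries are recoverable without a separate attention mask. get_cu_seqlens_from_pos_ids is defined elsewhere in axolotl; the sketch below is an illustrative re-derivation of what it computes for a single packed row, not the committed helper:

import torch

def cu_seqlens_from_pos_ids(position_ids: torch.Tensor):
    # hypothetical stand-in for axolotl's get_cu_seqlens_from_pos_ids
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()  # a sample begins wherever position resets to 0
    cu_seqlens = torch.cat([starts, torch.tensor([pos.numel()])]).to(torch.int32)
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen

pos = torch.tensor([[0, 1, 2, 3, 4, 0, 1, 2]])
print(cu_seqlens_from_pos_ids(pos))  # (tensor([0, 5, 8], dtype=torch.int32), 5)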
