[phi2] fix the apply rotary to account for position_ids when packing #877

Closed
wants to merge 1 commit into from
3 changes: 3 additions & 0 deletions examples/phi/phi-ft.yml
@@ -4,6 +4,9 @@ tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true

model_config:
flash_attn: true

load_in_8bit: false
load_in_4bit: false
strict: false
114 changes: 93 additions & 21 deletions src/axolotl/models/phi/modeling_phi.py
@@ -25,11 +25,14 @@
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
from flash_attn.ops.fused_dense import FusedDense
except: # noqa: E722
pad_input, unpad_input = None, None
FlashRotaryEmbedding = None
FlashSelfAttention, FlashCrossAttention = None, None

try:
from flash_attn.ops.fused_dense import FusedDense
except: # noqa: E722
FusedDense = None


@@ -91,18 +94,27 @@ def _apply_rotary_emb(
x: torch.FloatTensor,
cos: torch.FloatTensor,
sin: torch.FloatTensor,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
_, seqlen, _, _ = x.shape
_, rotary_dim = cos.shape
rotary_dim *= 2

# Ensure position_ids is broadcastable to the shape of x
position_ids = position_ids.view(-1, 1, 1)

x_rot = x[:, :, :, :rotary_dim]
x_pass = x[:, :, :, rotary_dim:]

x1, x2 = x_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
sin[:seqlen], "s d -> s 1 d"
)
# c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
# sin[:seqlen], "s d -> s 1 d"
# )

# Select cos and sin values based on position_ids
c = cos[position_ids].expand(-1, -1, -1, rotary_dim // 2)
s = sin[position_ids].expand(-1, -1, -1, rotary_dim // 2)

x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]

x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
@@ -116,18 +128,27 @@ def _apply_rotary_emb_kv(
sin: torch.FloatTensor,
cos_k: Optional[torch.FloatTensor] = None,
sin_k: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
_, seqlen, _, _, _ = kv.shape
_, rotary_dim = cos.shape
rotary_dim *= 2

# Ensure position_ids is broadcastable to the shape of x
position_ids = position_ids.view(-1, 1, 1)

k_rot = kv[:, :, 0, :, :rotary_dim]
k_pass = kv[:, :, 0, :, rotary_dim:]

k1, k2 = k_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
sin[:seqlen], "s d -> s 1 d"
)
# c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
# sin[:seqlen], "s d -> s 1 d"
# )

# Select cos and sin values based on position_ids
c = cos[position_ids].expand(-1, -1, -1, rotary_dim // 2)
s = sin[position_ids].expand(-1, -1, -1, rotary_dim // 2)

k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]

k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
@@ -141,17 +162,27 @@ def _apply_rotary_emb_kv(
)


@torch.jit.script
def _apply_rotary_emb_qkv(
qkv: torch.FloatTensor,
cos: torch.FloatTensor,
sin: torch.FloatTensor,
cos_k: Optional[torch.FloatTensor] = None,
sin_k: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
_, seqlen, _, _, _ = qkv.shape
_, rotary_dim = cos.shape
rotary_dim *= 2

# position_ids = position_ids.view(-1, 1, 1)
# Check if position_ids is provided, if not, create a default range
if position_ids is None:
position_ids = torch.arange(seqlen, device=qkv.device).view(1, -1)

pos_len = position_ids.size(1)
position_ids = position_ids.view(pos_len, 1, 1)

q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]

@@ -160,9 +191,14 @@ def _apply_rotary_emb_qkv(

q1, q2 = q_rot.chunk(2, dim=-1)
k1, k2 = k_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
sin[:seqlen], "s d -> s 1 d"
)
# c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
# sin[:seqlen], "s d -> s 1 d"
# )

# Select cos and sin values based on position_ids
c = cos[position_ids].squeeze(-2)
s = sin[position_ids].squeeze(-2)

q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]

q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
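The core of the change is visible in the rotary hunks above: instead of slicing the cached cos/sin tables by positions 0..seqlen-1, they are indexed with the per-token position_ids, so samples packed into the same row each restart their rotations at position 0. A minimal sketch of that indexing, with shapes chosen only for illustration (not taken from the PR):

```python
import torch

# Minimal sketch, not the PR's code: index the cached cos/sin tables with
# per-token position_ids so packed samples each restart at position 0.
batch, seqlen, n_heads, rotary_dim = 1, 7, 2, 8
cos = torch.randn(2048, rotary_dim // 2)   # assumed cache: [max_positions, rotary_dim / 2]
sin = torch.randn(2048, rotary_dim // 2)
x = torch.randn(batch, seqlen, n_heads, rotary_dim)

# Two samples of lengths 3 and 4 packed into one row.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])

c = cos[position_ids].unsqueeze(2)         # [batch, seqlen, 1, rotary_dim / 2]
s = sin[position_ids].unsqueeze(2)

x1, x2 = x.chunk(2, dim=-1)
x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
```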
@@ -279,6 +315,7 @@ def forward(
qkv: torch.Tensor,
kv: Optional[torch.Tensor] = None,
seqlen_offset: int = 0,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_start = seqlen_offset
@@ -298,17 +335,20 @@ def forward(
qkv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
position_ids=position_ids,
)
else:
q = _apply_rotary_emb(
qkv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
position_ids=position_ids,
)
kv = _apply_rotary_emb_kv(
kv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
position_ids=position_ids,
)

return q, kv
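For context, when sample packing is enabled the collator concatenates several short samples into one row and emits position_ids that restart at zero at every sample boundary, which is what the rotary code above now respects. A hypothetical example of that layout:

```python
import torch

# Hypothetical packed row: two samples of lengths 3 and 4.
sample_lengths = [3, 4]
position_ids = torch.cat([torch.arange(n) for n in sample_lengths]).unsqueeze(0)
# tensor([[0, 1, 2, 0, 1, 2, 3]])
```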
@@ -631,18 +671,20 @@ def _forward_self_attn(
key_padding_mask: Optional[torch.BoolTensor],
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
qkv = self.Wqkv(x)
qkv = rearrange(
qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
)

if self.rotary_dim > 0:
qkv = self.rotary_emb(qkv)
qkv = self.rotary_emb(qkv, position_ids=position_ids)

if self.flash_attn:
batch_size, seqlen = qkv.shape[0], qkv.shape[1]

return_w_pad = False
if (
key_padding_mask is not None
and cu_seqlens is None
@@ -653,6 +695,7 @@ def _forward_self_attn(
qkv, indices, cu_seqlens, max_seqlen = unpad_input(
qkv, key_padding_mask
)
return_w_pad = True

if self.checkpointing:
attn_output = torch.utils.checkpoint.checkpoint(
@@ -666,7 +709,7 @@ def _forward_self_attn(
# If `key_padding_mask` is supplied, we need to pad the output back to the original shape
return (
pad_input(attn_output, indices, batch_size, seqlen)
if key_padding_mask is not None
if key_padding_mask is not None and return_w_pad
else attn_output
)
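The return_w_pad flag introduced above separates two flash-attention paths: when only key_padding_mask is given, the module unpads internally and must pad the output back; when the caller already supplies cu_seqlens/max_seqlen for a packed batch, the output stays in packed layout and is returned as-is. A rough sketch of how such cumulative sequence lengths are typically built (names assumed, not from the PR):

```python
import torch

# Rough sketch: cumulative sequence lengths for a packed row holding samples
# of lengths 3 and 4, in the int32 form flash-attn's varlen kernels expect.
sample_lengths = torch.tensor([3, 4])
cu_seqlens = torch.nn.functional.pad(sample_lengths.cumsum(0), (1, 0)).to(torch.int32)
max_seqlen = int(sample_lengths.max())
# cu_seqlens -> tensor([0, 3, 7], dtype=torch.int32): tokens 0..2 are sample 0,
# tokens 3..6 are sample 1.
```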

@@ -682,6 +725,9 @@ def _forward_cross_attn(
x: torch.FloatTensor,
past_key_values: Optional[InferenceParams],
key_padding_mask: Optional[torch.BoolTensor],
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
batch_size = x.shape[0]

@@ -698,7 +744,9 @@ def _forward_cross_attn(
)
causal = None if seqlen_offset == 0 else False
if self.rotary_dim > 0:
q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
q, kv = self.rotary_emb(
q, kv=kv, seqlen_offset=seqlen_offset, position_ids=position_ids
)

if past_key_values is not None:
kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
@@ -708,12 +756,17 @@ def _forward_cross_attn(
seqlen_k = kv.shape[1]

cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
None,
None,
None,
None,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
)
if key_padding_mask is not None:
return_w_pad = False
if (
key_padding_mask is not None
and cu_seqlens is None
and max_seqlen is None
):
kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)

if seqlen_q == 1:
@@ -724,6 +777,7 @@ def _forward_cross_attn(
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, key_padding_mask
)
return_w_pad = True

if self.checkpointing:
attn_output = torch.utils.checkpoint.checkpoint(
@@ -749,7 +803,7 @@ def _forward_cross_attn(

return (
pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
if key_padding_mask is not None
if key_padding_mask is not None and return_w_pad
else attn_output
)

@@ -773,6 +827,7 @@ def forward(
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
# TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
@@ -786,7 +841,11 @@ def forward(
if past_key_values is None:
# If `past_key_values` are not supplied, we run self-attention
attn_output = self._forward_self_attn(
x, attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
x,
attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
else:
# If `past_key_values` are supplied, it means that we might have cached values and
@@ -797,12 +856,20 @@ def forward(
attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
# MQA / GQA
else:
# Regardless of `past_key_values` being supplied or not, it always use cross-attention
# because `q` and `kv` lengths might be different
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
attn_output = self._forward_cross_attn(
x,
past_key_values,
attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)

output = rearrange(attn_output, "... h d -> ... (h d)")
output = self.out_proj(output)
@@ -836,6 +903,7 @@ def forward(
hidden_states: torch.FloatTensor,
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
attention_mask: Optional[torch.BoolTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
@@ -845,6 +913,7 @@ def forward(
hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
)
if isinstance(attn_outputs, tuple):
attn_outputs = attn_outputs[0]
@@ -990,6 +1059,7 @@ def forward(
attention_mask: Optional[torch.BoolTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
hidden_states = self.embd(input_ids)

@@ -1000,6 +1070,7 @@ def forward(
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)

return hidden_states
@@ -1051,6 +1122,7 @@ def forward(
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
lm_logits = self.lm_head(hidden_states)

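Taken together, the new position_ids argument is threaded from the top-level forward through the mixer down to the rotary embedding. A hypothetical end-to-end call with a packed row (model handle and vocabulary size are illustrative only):

```python
import torch

# Hypothetical usage: two samples packed into one row of length 7.
input_ids = torch.randint(0, 51_200, (1, 7))
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])

# Assuming `model` is a loaded PhiForCausalLM with flash attention enabled:
# logits = model(input_ids, position_ids=position_ids).logits
```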
1 change: 1 addition & 0 deletions src/axolotl/utils/models.py
@@ -293,6 +293,7 @@ def load_model(

model = PhiForCausalLM.from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
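The one-line change above passes the already-built model_config into from_pretrained, so overrides from the yml's new model_config block (such as flash_attn: true) are not silently dropped when the config is re-read from the hub. A simplified sketch of the assumed flow, with the model id and override dict chosen purely for illustration:

```python
from transformers import AutoConfig

base_model = "microsoft/phi-1_5"        # illustrative model id, not from the PR
overrides = {"flash_attn": True}        # mirrors the yml `model_config:` block

config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
for key, value in overrides.items():
    setattr(config, key, value)

# Passing config= keeps the override; without it, from_pretrained rebuilds the
# config from the hub files and the flag is lost.
# model = PhiForCausalLM.from_pretrained(base_model, config=config, trust_remote_code=True)
```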