incorporated some suggestions from the pr
ShashankMosaicML committed Oct 13, 2023
1 parent b0960f7 commit f6632e1
Showing 2 changed files with 16 additions and 19 deletions.
24 changes: 12 additions & 12 deletions llmfoundry/models/layers/attention.py
@@ -5,7 +5,7 @@

 import math
 import warnings
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional

 import torch
 import torch.nn as nn
@@ -55,7 +55,7 @@ def scaled_multihead_dot_product_attention(
     value: torch.Tensor,
     n_heads: int,
     kv_n_heads: Optional[int] = None,
-    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
     softmax_scale: Optional[float] = None,
     attn_bias: Optional[torch.Tensor] = None,
     key_padding_mask: Optional[torch.Tensor] = None,
@@ -64,7 +64,7 @@ def scaled_multihead_dot_product_attention(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:

     if multiquery:
@@ -168,7 +168,7 @@ def scaled_multihead_dot_product_attention(


 def check_valid_inputs(*tensors: torch.Tensor,
-                       valid_dtypes: Optional[List[torch.dtype]] = None):
+                       valid_dtypes: Optional[list[torch.dtype]] = None):
     if valid_dtypes is None:
         valid_dtypes = [torch.float16, torch.bfloat16]
     for tensor in tensors:
@@ -184,7 +184,7 @@ def flash_attn_fn(
     value: torch.Tensor,
     n_heads: int,
     kv_n_heads: Optional[int] = None,
-    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
     softmax_scale: Optional[float] = None,
     attn_bias: Optional[torch.Tensor] = None,
     key_padding_mask: Optional[torch.Tensor] = None,
@@ -193,7 +193,7 @@ def flash_attn_fn(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:
     try:
         from flash_attn import bert_padding, flash_attn_interface  # type: ignore # yapf: disable # isort: skip
@@ -304,7 +304,7 @@ def triton_flash_attn_fn(
     value: torch.Tensor,
     n_heads: int,
     kv_n_heads: Optional[int] = None,
-    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
     softmax_scale: Optional[float] = None,
     attn_bias: Optional[torch.Tensor] = None,
     key_padding_mask: Optional[torch.Tensor] = None,
@@ -313,7 +313,7 @@ def triton_flash_attn_fn(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:
     try:
         from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
@@ -519,13 +519,13 @@ def __init__(
     def forward(
         self,
         x: torch.Tensor,
-        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         attn_bias: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb_w_offset_info: Optional[Dict] = None,
+        rotary_emb_w_offset_info: Optional[dict] = None,
         is_causal: bool = True,
         needs_weights: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
             torch.Tensor, torch.Tensor]]]:
         qkv = self.Wqkv(x)

@@ -663,7 +663,7 @@ def __init__(
 def attn_bias_shape(
         attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
         prefix_lm: bool, causal: bool,
-        use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
+        use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
     if attn_impl == 'flash':
         return None
     elif attn_impl in ['torch', 'triton']:
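The attention.py edits above replace the typing aliases (Tuple, List, Dict) with the built-in generic types (tuple, list, dict) standardized in PEP 585, which can be used directly in annotations on Python 3.9+. A minimal sketch of the same annotation style, assuming Python 3.9+ with torch installed; the helper below is illustrative and not part of llm-foundry:

from typing import Optional

import torch


def split_past_key_value(
    past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    valid_dtypes: Optional[list[torch.dtype]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Built-in generics (tuple[...], list[...]) replace typing.Tuple/typing.List.
    if valid_dtypes is None:
        valid_dtypes = [torch.float16, torch.bfloat16]
    if past_key_value is None:
        return None, None
    key, value = past_key_value
    if key.dtype not in valid_dtypes or value.dtype not in valid_dtypes:
        raise TypeError(f'expected dtypes in {valid_dtypes}')
    return key, value


# Usage: the annotations work at runtime on 3.9+ without importing typing.Tuple.
k, v = split_past_key_value((torch.ones(1, 2, dtype=torch.float16),
                             torch.ones(1, 2, dtype=torch.float16)))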
11 changes: 4 additions & 7 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -390,7 +390,6 @@ def forward(
         ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

         rotary_emb_w_offset_info = None
-        pos_emb = 0.0
         tok_emb = self.wte(input_ids)
         if self.learned_pos_emb or self.rope:
             past_position = 0
@@ -434,12 +433,10 @@ def forward(
                     'pos': pos,
                     'seq_len': S + past_position
                 }
-            x = tok_emb
-            if self.learned_pos_emb:
-                pos_emb = self.wpe(pos)
-                x = x + pos_emb
-
-        x = tok_emb + pos_emb
+        x = tok_emb
+        if self.learned_pos_emb:
+            pos_emb = self.wpe(pos)
+            x = x + pos_emb

         if self.embedding_fraction == 1:
             x = self.emb_drop(x)
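In the modeling_mpt.py change, the pos_emb = 0.0 placeholder goes away: x starts as the token embedding, and the learned positional embedding is added only when learned_pos_emb is set, while RoPE positions are passed along separately through rotary_emb_w_offset_info. A standalone sketch of that embedding flow, assuming a toy module whose wte/wpe/learned_pos_emb attributes merely mirror the names in the diff (not the actual MPTModel):

import torch
import torch.nn as nn


class TinyEmbeddingFrontend(nn.Module):
    # Illustrative sketch, not the llm-foundry implementation.

    def __init__(self, vocab_size: int, max_seq_len: int, d_model: int,
                 learned_pos_emb: bool = True):
        super().__init__()
        self.learned_pos_emb = learned_pos_emb
        self.wte = nn.Embedding(vocab_size, d_model)  # token embeddings
        self.wpe = nn.Embedding(max_seq_len, d_model)  # learned position embeddings

    def forward(self, input_ids: torch.Tensor,
                past_position: int = 0) -> torch.Tensor:
        tok_emb = self.wte(input_ids)
        x = tok_emb
        if self.learned_pos_emb:
            seq_len = input_ids.size(1)
            pos = torch.arange(past_position,
                               past_position + seq_len,
                               device=input_ids.device).unsqueeze(0)
            x = x + self.wpe(pos)  # added only when positions are learned
        return x


# Usage: with learned_pos_emb=False (e.g. when relying on RoPE), x stays tok_emb.
frontend = TinyEmbeddingFrontend(vocab_size=100, max_seq_len=32, d_model=16)
out = frontend(torch.randint(0, 100, (2, 8)))
print(out.shape)  # torch.Size([2, 8, 16])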
