diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index d441531d29..b3bb6e193d 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,7 +5,7 @@ import math import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import torch import torch.nn as nn @@ -55,7 +55,7 @@ def scaled_multihead_dot_product_attention( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -64,7 +64,7 @@ def scaled_multihead_dot_product_attention( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: if multiquery: @@ -168,7 +168,7 @@ def scaled_multihead_dot_product_attention( def check_valid_inputs(*tensors: torch.Tensor, - valid_dtypes: Optional[List[torch.dtype]] = None): + valid_dtypes: Optional[list[torch.dtype]] = None): if valid_dtypes is None: valid_dtypes = [torch.float16, torch.bfloat16] for tensor in tensors: @@ -184,7 +184,7 @@ def flash_attn_fn( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -193,7 +193,7 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip @@ -304,7 +304,7 @@ def triton_flash_attn_fn( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -313,7 +313,7 @@ def triton_flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: from llmfoundry.models.layers.flash_attn_triton import flash_attn_func @@ -519,13 +519,13 @@ def __init__( def forward( self, x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb_w_offset_info: Optional[Dict] = None, + rotary_emb_w_offset_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -663,7 +663,7 @@ def __init__( def attn_bias_shape( attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, - use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]: + use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]: if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 62c7da7cd5..18050da079 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -390,7 +390,6 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_offset_info = None - pos_emb = 0.0 tok_emb = self.wte(input_ids) if self.learned_pos_emb or self.rope: past_position = 0 @@ -434,12 +433,10 @@ def forward( 'pos': pos, 'seq_len': S + past_position } - x = tok_emb - if self.learned_pos_emb: - pos_emb = self.wpe(pos) - x = x + pos_emb - - x = tok_emb + pos_emb + x = tok_emb + if self.learned_pos_emb: + pos_emb = self.wpe(pos) + x = x + pos_emb if self.embedding_fraction == 1: x = self.emb_drop(x)