refactor: remove deprecated rotary embedding classes and streamline position embeddings handling
h3110Fr13nd committed Dec 11, 2024
1 parent 44c8b2f commit 9bf4eae
Showing 2 changed files with 9 additions and 75 deletions.
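
As context for the diff below, the streamlined handling means the RoPE cos/sin pair is computed once by the model and passed to every attention layer as `position_embeddings`, instead of each layer rebuilding it from `position_ids`. A minimal sketch of that flow, assuming the public `PhiConfig`/`PhiRotaryEmbedding` API (this snippet is illustrative and not part of the commit):

# Minimal sketch of the streamlined flow (illustrative; not code from this diff).
import torch
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiRotaryEmbedding

config = PhiConfig()
rotary_emb = PhiRotaryEmbedding(config=config)

hidden_states = torch.randn(1, 8, config.hidden_size)  # (batch, seq_len, hidden)
position_ids = torch.arange(8).unsqueeze(0)            # (batch, seq_len)

# Computed once per forward pass in the model ...
cos, sin = rotary_emb(hidden_states, position_ids)

# ... then threaded through each decoder layer / attention module, e.g.:
# layer_outputs = decoder_layer(hidden_states, ..., position_embeddings=(cos, sin))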
30 changes: 0 additions & 30 deletions src/transformers/models/phi/modeling_phi.py
@@ -46,18 +46,13 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_phi import PhiConfig


if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
_CONFIG_FOR_DOC = "PhiConfig"
@@ -150,31 +145,6 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

def __init__(self, *args, **kwargs):
logger.warning_once(
"`PhiLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`PhiRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
)
kwargs["rope_type"] = "linear"
super().__init__(*args, **kwargs)


class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

def __init__(self, *args, **kwargs):
logger.warning_once(
"`PhiDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`PhiRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
"__init__)."
)
kwargs["rope_type"] = "dynamic"
super().__init__(*args, **kwargs)


class PhiMLP(nn.Module):
def __init__(self, config):
super().__init__()
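
The two classes deleted above were thin deprecation shims; their behaviour now lives in `PhiRotaryEmbedding`, which picks the scaling variant from the model config. A rough sketch of the replacement path, assuming the standard `rope_scaling` config keys (`rope_type`, `factor`), which are not shown in this diff:

# Rough sketch, assuming the standard `rope_scaling` convention (not part of this diff).
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiRotaryEmbedding

# Linear scaling, formerly PhiLinearScalingRotaryEmbedding
linear_rope = PhiRotaryEmbedding(
    config=PhiConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
)

# Dynamic NTK scaling, formerly PhiDynamicNTKScalingRotaryEmbedding
dynamic_rope = PhiRotaryEmbedding(
    config=PhiConfig(rope_scaling={"rope_type": "dynamic", "factor": 2.0})
)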
54 changes: 9 additions & 45 deletions src/transformers/models/phi/modular_phi.py
@@ -39,10 +39,8 @@
from ..clip.modeling_clip import CLIPMLP
from ..gemma.modeling_gemma import GemmaForCausalLM
from ..llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaLinearScalingRotaryEmbedding,
LlamaModel,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
@@ -236,15 +234,6 @@ def __init__(
class PhiRotaryEmbedding(LlamaRotaryEmbedding):
pass


class PhiLinearScalingRotaryEmbedding(LlamaLinearScalingRotaryEmbedding, PhiRotaryEmbedding):
pass


class PhiDynamicNTKScalingRotaryEmbedding(LlamaDynamicNTKScalingRotaryEmbedding, PhiRotaryEmbedding):
pass


class PhiMLP(CLIPMLP):
pass

@@ -304,7 +293,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -320,16 +309,7 @@
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = position_embeddings

# Partial rotary embedding
query_rot, query_pass = (
@@ -415,7 +395,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# PhiFlashAttention2 attention does not support output_attentions
@@ -439,16 +419,7 @@
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = position_embeddings

# Partial rotary embedding
query_rot, query_pass = (
@@ -551,7 +522,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -568,6 +539,8 @@
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

bsz, q_len, _ = hidden_states.size()
@@ -584,16 +557,7 @@
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = position_embeddings

# Partial rotary embedding
query_rot, query_pass = (
@@ -680,7 +644,7 @@ def forward(
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
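
With the `position_ids` fallback removed in the hunks above, anything that calls an attention module directly needs to precompute the `(cos, sin)` tuple itself. A hedged sketch of such a call, using the eager `PhiAttention` signature shown in this diff; the shapes, defaults, and missing attention mask are illustrative only:

# Illustrative direct call after this change (not code from the diff): the
# fallback is gone, so (cos, sin) is computed up front and passed explicitly.
import torch
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiAttention, PhiRotaryEmbedding

config = PhiConfig()
attn = PhiAttention(config, layer_idx=0)
rotary_emb = PhiRotaryEmbedding(config=config)

hidden_states = torch.randn(1, 8, config.hidden_size)
position_ids = torch.arange(8).unsqueeze(0)
cos, sin = rotary_emb(hidden_states, position_ids)

attn_output, attn_weights, _ = attn(
    hidden_states=hidden_states,
    attention_mask=None,            # no causal mask in this toy example
    position_ids=position_ids,
    position_embeddings=(cos, sin),
)
print(attn_output.shape)            # (batch, seq_len, hidden_size)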
