diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 56bfe364e4ebc9..551f610908bb5d 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -336,8 +336,10 @@ def __init__(self, config: Phi3Config, device=None):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
         if self.rope_type == "longrope":
-            short_inv_freq, _ = self.rope_init_fn(self.config, device, seq_len=0, **self.rope_kwargs)
-            self.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
+            long_inv_freq, _ = self.rope_init_fn(
+                self.config, device, seq_len=config.original_max_position_embeddings + 1, **self.rope_kwargs
+            )
+            self.register_buffer("long_inv_freq", long_inv_freq, persistent=False)
 
     def _dynamic_frequency_update(self, position_ids, device):
         """
@@ -365,8 +367,8 @@ def forward(self, x, position_ids):
             inv_freq = self.inv_freq
         elif self.rope_type == "longrope":
             seq_len = torch.max(position_ids) + 1
-            if seq_len <= self.config.original_max_position_embeddings:
-                inv_freq = self.short_inv_freq
+            if seq_len > self.config.original_max_position_embeddings:
+                inv_freq = self.long_inv_freq
 
         # Core RoPE block
         inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py
index 0b52e5bc4e9f4d..af44efadd341cf 100644
--- a/src/transformers/models/phi3/modular_phi3.py
+++ b/src/transformers/models/phi3/modular_phi3.py
@@ -216,8 +216,10 @@ class Phi3RotaryEmbedding(MistralRotaryEmbedding):
     def __init__(self, config: Phi3Config, device=None):
         super().__init__(config, device)
         if self.rope_type == "longrope":
-            short_inv_freq, _ = self.rope_init_fn(self.config, device, seq_len=0, **self.rope_kwargs)
-            self.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
+            long_inv_freq, _ = self.rope_init_fn(
+                self.config, device, seq_len=config.original_max_position_embeddings + 1, **self.rope_kwargs
+            )
+            self.register_buffer("long_inv_freq", long_inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids):
@@ -227,8 +229,8 @@ def forward(self, x, position_ids):
             inv_freq = self.inv_freq
         elif self.rope_type == "longrope":
             seq_len = torch.max(position_ids) + 1
-            if seq_len <= self.config.original_max_position_embeddings:
-                inv_freq = self.short_inv_freq
+            if seq_len > self.config.original_max_position_embeddings:
+                inv_freq = self.long_inv_freq
 
         # Core RoPE block
         inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
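
The patch applies the same change to both files (modular_phi3.py and its generated modeling_phi3.py counterpart): instead of caching a "short" inverse-frequency buffer and selecting it when the sequence fits within the original context window, it caches the "long" frequencies (initialized with seq_len past original_max_position_embeddings) and switches to them only when the sequence actually exceeds that window, leaving the default inv_freq in place otherwise. The following is a minimal, self-contained sketch of the selection behavior the patch installs. It is not the transformers API: head_dim, rope_theta, the window size, and the per-dimension scaling factors are made-up stand-ins for values that would come from Phi3Config.

import torch

# Hypothetical stand-in values; in Phi3 these come from the model config.
head_dim = 64
rope_theta = 10000.0
original_max_position_embeddings = 4096

def compute_inv_freq(factor):
    # Standard RoPE inverse frequencies with a per-dimension rescaling factor.
    # This mirrors the longrope idea only in spirit, not the exact upstream math.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float) / head_dim
    return 1.0 / (torch.tensor(factor, dtype=torch.float) * rope_theta**exponents)

# Default frequencies (analogous to self.inv_freq) and the precomputed
# long-context buffer (analogous to the new self.long_inv_freq).
short_inv_freq = compute_inv_freq([1.0] * (head_dim // 2))
long_inv_freq = compute_inv_freq([4.0] * (head_dim // 2))

def select_inv_freq(position_ids):
    # The selection logic after the patch: keep the default frequencies and
    # switch to the cached long ones only past the original context length.
    seq_len = torch.max(position_ids) + 1
    if seq_len > original_max_position_embeddings:
        return long_inv_freq
    return short_inv_freq

# Exactly at the window boundary the default frequencies are still used;
# one token beyond it triggers the long-context buffer.
assert torch.equal(select_inv_freq(torch.arange(4096)), short_inv_freq)
assert torch.equal(select_inv_freq(torch.arange(4097)), long_inv_freq)

One plausible reading of the design choice: by inverting which buffer is precomputed, the default self.inv_freq remains the short-context frequencies, so only the past-window case needs a dedicated buffer and the common short-sequence path takes no extra branch state.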