Commit f3ff26a: phi3 longrope

Cyrilvallez committed Dec 20, 2024
1 parent 41dc110

Showing 20 changed files with 66 additions and 94 deletions.
9 changes: 2 additions & 7 deletions src/transformers/modeling_rope_utils.py
@@ -279,15 +279,10 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-        if seq_len and seq_len < config.original_max_position_embeddings:
-            expanded_max_position_embeddings = config.original_max_position_embeddings
-        else:
-            expanded_max_position_embeddings = config.max_position_embeddings
         max_position_embeddings = config.original_max_position_embeddings
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
         max_position_embeddings = config.max_position_embeddings
-        expanded_max_position_embeddings = max_position_embeddings * factor
 
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
@@ -297,7 +292,7 @@ def _compute_longrope_parameters(
             attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
 
     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if expanded_max_position_embeddings > max_position_embeddings:
+    if seq_len and seq_len > max_position_embeddings:
         ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
         ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
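Net effect in `_compute_longrope_parameters`: `factor` is now always the ratio `max_position_embeddings / original_max_position_embeddings`, and the long rescale factors are applied only when the runtime `seq_len` actually exceeds the pretrained context. Below is a minimal standalone sketch of the resulting logic; it is illustrative only, not the library function itself, `config` is a hypothetical stand-in with Phi-3-style fields, and the inverse-frequency formula is the standard RoPE one assumed here.

import math

import torch


def longrope_inv_freq_sketch(config, device=None, seq_len=None):
    """Illustrative sketch mirroring the branch structure in the diff above."""
    dim = config.hidden_size // config.num_attention_heads  # rotary dim, assuming the full head_dim is rotated
    factor = config.max_position_embeddings / config.original_max_position_embeddings
    # Default attention scaling derived from the ratio above, as in the context lines of this hunk
    attention_factor = math.sqrt(1 + math.log(factor) / math.log(config.original_max_position_embeddings))

    # Long rescale factors only apply once the running sequence exceeds the pretrained
    # context; otherwise the short factors are used (lists of length dim // 2).
    if seq_len and seq_len > config.original_max_position_embeddings:
        ext_factors = torch.tensor(config.rope_scaling["long_factor"], dtype=torch.float32, device=device)
    else:
        ext_factors = torch.tensor(config.rope_scaling["short_factor"], dtype=torch.float32, device=device)

    # Standard RoPE inverse frequencies, rescaled per dimension by the chosen factors
    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
    inv_freq = 1.0 / (ext_factors * config.rope_theta**inv_freq_shape)
    return inv_freq, attention_factor

With `seq_len=None`, or any value at or below `original_max_position_embeddings`, the sketch returns the short-factor frequencies; only longer sequences pick up the long factors.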
6 changes: 1 addition & 5 deletions src/transformers/models/aria/modeling_aria.py
@@ -723,11 +723,7 @@ def _init_weights(self, module):
 
 
 class AriaTextRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: AriaTextConfig,
-        device=None,
-    ):
+    def __init__(self, config: AriaTextConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -120,11 +120,7 @@ def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=
 
 
 class BambaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: BambaConfig,
-        device=None,
-    ):
+    def __init__(self, config: BambaConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -73,11 +73,7 @@ def forward(self, hidden_states):
 
 
 class CohereRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: CohereConfig,
-        device=None,
-    ):
+    def __init__(self, config: CohereConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -53,11 +53,7 @@
 
 
 class Cohere2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Cohere2Config,
-        device=None,
-    ):
+    def __init__(self, config: Cohere2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -92,11 +92,7 @@ def forward(self, x):
 
 
 class GemmaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: GemmaConfig,
-        device=None,
-    ):
+    def __init__(self, config: GemmaConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -324,11 +324,7 @@ def forward(
 
 
 class Gemma2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Gemma2Config,
-        device=None,
-    ):
+    def __init__(self, config: Gemma2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/glm/modeling_glm.py
@@ -255,11 +255,7 @@ def extra_repr(self):
 
 
 class GlmRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: GlmConfig,
-        device=None,
-    ):
+    def __init__(self, config: GlmConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/granite/modeling_granite.py
@@ -309,11 +309,7 @@ def forward(
 
 
 class GraniteRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: GraniteConfig,
-        device=None,
-    ):
+    def __init__(self, config: GraniteConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/llama/modeling_llama.py
@@ -80,11 +80,7 @@ def extra_repr(self):
 
 
 class LlamaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: LlamaConfig,
-        device=None,
-    ):
+    def __init__(self, config: LlamaConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -270,11 +270,7 @@ def forward(
 
 
 class MistralRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: MistralConfig,
-        device=None,
-    ):
+    def __init__(self, config: MistralConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -392,11 +392,7 @@ def forward(
 
 
 class MixtralRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: MixtralConfig,
-        device=None,
-    ):
+    def __init__(self, config: MixtralConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/olmo/modeling_olmo.py
@@ -274,11 +274,7 @@ def forward(
 
 
 class OlmoRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: OlmoConfig,
-        device=None,
-    ):
+    def __init__(self, config: OlmoConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/olmo2/modeling_olmo2.py
@@ -275,11 +275,7 @@ def forward(
 
 
 class Olmo2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Olmo2Config,
-        device=None,
-    ):
+    def __init__(self, config: Olmo2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/phi/modeling_phi.py
@@ -270,11 +270,7 @@ def forward(
 
 
 class PhiRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: PhiConfig,
-        device=None,
-    ):
+    def __init__(self, config: PhiConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
17 changes: 11 additions & 6 deletions src/transformers/models/phi3/modeling_phi3.py
@@ -318,11 +318,7 @@ def forward(
 
 
 class Phi3RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Phi3Config,
-        device=None,
-    ):
+    def __init__(self, config: Phi3Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
@@ -339,6 +335,9 @@ def __init__(
         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
+        if self.rope_type == "longrope":
+            short_inv_freq, _ = self.rope_init_fn(self.config, device, seq_len=0, **self.rope_kwargs)
+            self.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
 
     def _dynamic_frequency_update(self, position_ids, device):
         """
@@ -360,11 +359,17 @@ def _dynamic_frequency_update(self, position_ids, device):
 
     @torch.no_grad()
     def forward(self, x, position_ids):
+        inv_freq = self.inv_freq
         if "dynamic" in self.rope_type:
             self._dynamic_frequency_update(position_ids, device=x.device)
+            inv_freq = self.inv_freq
+        elif self.rope_type == "longrope":
+            seq_len = torch.max(position_ids) + 1
+            if seq_len <= self.config.original_max_position_embeddings:
+                inv_freq = self.short_inv_freq
 
         # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
         device_type = x.device.type
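The behavioral change in `forward`: for longrope the module now keeps a second buffer, `short_inv_freq` (computed with `seq_len=0`), and picks the set of inverse frequencies per call from the largest position id it sees. A toy illustration of just that dispatch, with hypothetical values not taken from any real checkpoint:

import torch

# Hypothetical stand-ins for the two buffers registered in __init__ above
inv_freq = torch.tensor([1.0, 0.1, 0.01])
short_inv_freq = torch.tensor([2.0, 0.2, 0.02])
original_max_position_embeddings = 4096  # hypothetical pretrained context length


def select_inv_freq(position_ids: torch.Tensor) -> torch.Tensor:
    # Mirrors the longrope branch added to forward(): positions that still fit in the
    # pretrained window use short_inv_freq; anything longer falls back to inv_freq.
    seq_len = torch.max(position_ids) + 1
    return short_inv_freq if seq_len <= original_max_position_embeddings else inv_freq


print(select_inv_freq(torch.arange(10)[None, :]))    # short_inv_freq (seq_len 10 <= 4096)
print(select_inv_freq(torch.arange(8192)[None, :]))  # inv_freq (seq_len 8192 > 4096)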
35 changes: 34 additions & 1 deletion src/transformers/models/phi3/modular_phi3.py
@@ -213,7 +213,40 @@ def forward(
 
 
 class Phi3RotaryEmbedding(MistralRotaryEmbedding):
-    pass
+    def __init__(self, config: Phi3Config, device=None):
+        super().__init__(config, device)
+        if self.rope_type == "longrope":
+            short_inv_freq, _ = self.rope_init_fn(self.config, device, seq_len=0, **self.rope_kwargs)
+            self.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        inv_freq = self.inv_freq
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids, device=x.device)
+            inv_freq = self.inv_freq
+        elif self.rope_type == "longrope":
+            seq_len = torch.max(position_ids) + 1
+            if seq_len <= self.config.original_max_position_embeddings:
+                inv_freq = self.short_inv_freq
+
+        # Core RoPE block
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 class Phi3PreTrainedModel(MistralPreTrainedModel):
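For shape intuition on the core RoPE block carried into the modular class: each position id is multiplied by each inverse frequency, the result is duplicated to span the full rotary dimension, and the longrope attention scaling is a plain scalar on cos and sin. A toy walkthrough with illustrative sizes:

import torch

batch, seq_len, dim = 2, 5, 8  # toy sizes: 8 rotary dims -> 4 inverse frequencies
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))  # shape (4,)
position_ids = torch.arange(seq_len).expand(batch, -1)                 # shape (2, 5)
attention_scaling = 1.0  # longrope swaps in sqrt(1 + log(factor) / log(original_max))

inv_freq_expanded = inv_freq[None, :, None].float().expand(batch, -1, 1)  # (2, 4, 1)
position_ids_expanded = position_ids[:, None, :].float()                  # (2, 1, 5)
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)       # (2, 5, 4)
emb = torch.cat((freqs, freqs), dim=-1)                                   # (2, 5, 8)
cos = emb.cos() * attention_scaling
sin = emb.sin() * attention_scaling
print(cos.shape, sin.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])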
6 changes: 1 addition & 5 deletions src/transformers/models/qwen2/modeling_qwen2.py
@@ -283,11 +283,7 @@ def forward(
 
 
 class Qwen2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Qwen2Config,
-        device=None,
-    ):
+    def __init__(self, config: Qwen2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
6 changes: 1 addition & 5 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -274,11 +274,7 @@ def forward(
 
 
 class Starcoder2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Starcoder2Config,
-        device=None,
-    ):
+    def __init__(self, config: Starcoder2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
3 changes: 3 additions & 0 deletions tests/models/phi3/test_modeling_phi3.py
@@ -459,6 +459,9 @@ def test_model_rope_scaling_short_long_factor(self, scaling_type):
             "long_factor": [5.0 for _ in range(n_factors)],
         }
         input_tensor = ids_tensor([1, 4090], config.vocab_size)
+        # Make sure we don't have padding tokens. If this is the case, then the actual number of "true" tokens may be shorter
+        # than `config.original_max_position_embeddings + 5`, invalidating this test
+        input_tensor[input_tensor == config.pad_token_id] += 1
         model = Phi3ForCausalLM(config)
         model.to(torch_device)
         model.eval()
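The added lines keep the random ids clear of the pad token, so all 4090 prompt positions count as real tokens when generation crosses `original_max_position_embeddings`. A toy version of the same bump, with a hypothetical pad id:

import torch

pad_token_id = 0  # hypothetical pad id
input_tensor = torch.tensor([[0, 3, 0, 7, 2]])
input_tensor[input_tensor == pad_token_id] += 1  # 0 -> 1, so nothing is treated as padding
print(input_tensor)  # tensor([[1, 3, 1, 7, 2]])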
