From f3ff26a17d23906293fccf2aad51141c0bbd932c Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Fri, 20 Dec 2024 17:52:18 +0100
Subject: [PATCH] phi3 longrope

---
 src/transformers/modeling_rope_utils.py       |  9 ++---
 src/transformers/models/aria/modeling_aria.py |  6 +---
 .../models/bamba/modeling_bamba.py            |  6 +---
 .../models/cohere/modeling_cohere.py          |  6 +---
 .../models/cohere2/modeling_cohere2.py        |  6 +---
 .../models/gemma/modeling_gemma.py            |  6 +---
 .../models/gemma2/modeling_gemma2.py          |  6 +---
 src/transformers/models/glm/modeling_glm.py   |  6 +---
 .../models/granite/modeling_granite.py        |  6 +---
 .../models/llama/modeling_llama.py            |  6 +---
 .../models/mistral/modeling_mistral.py        |  6 +---
 .../models/mixtral/modeling_mixtral.py        |  6 +---
 src/transformers/models/olmo/modeling_olmo.py |  6 +---
 .../models/olmo2/modeling_olmo2.py            |  6 +---
 src/transformers/models/phi/modeling_phi.py   |  6 +---
 src/transformers/models/phi3/modeling_phi3.py | 17 +++++---
 src/transformers/models/phi3/modular_phi3.py  | 35 ++++++++++++++++++-
 .../models/qwen2/modeling_qwen2.py            |  6 +---
 .../models/starcoder2/modeling_starcoder2.py  |  6 +---
 tests/models/phi3/test_modeling_phi3.py       |  3 ++
 20 files changed, 66 insertions(+), 94 deletions(-)

diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index c617420a5896de..55d63fa4514a88 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -279,15 +279,10 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-        if seq_len and seq_len < config.original_max_position_embeddings:
-            expanded_max_position_embeddings = config.original_max_position_embeddings
-        else:
-            expanded_max_position_embeddings = config.max_position_embeddings
         max_position_embeddings = config.original_max_position_embeddings
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
         max_position_embeddings = config.max_position_embeddings
-        expanded_max_position_embeddings = max_position_embeddings * factor
 
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
@@ -297,7 +292,7 @@ def _compute_longrope_parameters(
             attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
 
     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if expanded_max_position_embeddings > max_position_embeddings:
+    if seq_len and seq_len > max_position_embeddings:
         ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
         ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 6481d6f3c434c7..3547a1a0b49d50 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -723,11 +723,7 @@ def _init_weights(self, module):
 
 
 class AriaTextRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: AriaTextConfig,
-        device=None,
-    ):
+    def __init__(self, config: AriaTextConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
index c89d8d7853008d..e17e5ef712cf7a 100644
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ b/src/transformers/models/bamba/modeling_bamba.py
@@ -120,11 +120,7 @@ def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=
 
 
 class BambaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: BambaConfig,
-        device=None,
-    ):
+    def __init__(self, config: BambaConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 714f04a54ee3b8..58a1891f00ebc2 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -73,11 +73,7 @@ def forward(self, hidden_states):
 
 
 class CohereRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: CohereConfig,
-        device=None,
-    ):
+    def __init__(self, config: CohereConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index cefef6e98cd47a..d0463573d2b1d8 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -53,11 +53,7 @@
 
 
 class Cohere2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Cohere2Config,
-        device=None,
-    ):
+    def __init__(self, config: Cohere2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index e2ea12b03fe434..f2611b16c34558 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -92,11 +92,7 @@ def forward(self, x):
 
 
 class GemmaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: GemmaConfig,
-        device=None,
-    ):
+    def __init__(self, config: GemmaConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 67fc6c86a3bac6..0bf2b154f9c221 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -324,11 +324,7 @@ def forward(
 
 
 class Gemma2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Gemma2Config,
-        device=None,
-    ):
+    def __init__(self, config: Gemma2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index af8ab4f85e276f..87a1c75e5084ed 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -255,11 +255,7 @@ def extra_repr(self):
 
 
 class GlmRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: GlmConfig,
-        device=None,
-    ):
+    def __init__(self, config: GlmConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index 2e045e149d95de..49b6b7ed2df237 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -309,11 +309,7 @@ def forward(
 
 
 class GraniteRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: GraniteConfig,
-        device=None,
-    ):
+    def __init__(self, config: GraniteConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5be33c26414cd7..61e7fa0f6ca429 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -80,11 +80,7 @@ def extra_repr(self):
 
 
 class LlamaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: LlamaConfig,
-        device=None,
-    ):
+    def __init__(self, config: LlamaConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 90c38895b4280b..8c0d8af3ece737 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -270,11 +270,7 @@ def forward(
 
 
 class MistralRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: MistralConfig,
-        device=None,
-    ):
+    def __init__(self, config: MistralConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 84ed327d9be920..a726b69fb6688f 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -392,11 +392,7 @@ def forward(
 
 
 class MixtralRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: MixtralConfig,
-        device=None,
-    ):
+    def __init__(self, config: MixtralConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 11d3d99f4f72c9..8f43f303dc8869 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -274,11 +274,7 @@ def forward(
 
 
 class OlmoRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: OlmoConfig,
-        device=None,
-    ):
+    def __init__(self, config: OlmoConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 49ae798e7f1101..dc5893cad55473 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -275,11 +275,7 @@ def forward(
 
 
 class Olmo2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Olmo2Config,
-        device=None,
-    ):
+    def __init__(self, config: Olmo2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 477896decd5318..b4079c3ef7d417 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -270,11 +270,7 @@ def forward(
 
 
 class PhiRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: PhiConfig,
-        device=None,
-    ):
+    def __init__(self, config: PhiConfig, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index c8467c8df0424c..56bfe364e4ebc9 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -318,11 +318,7 @@ def forward(
 
 
 class Phi3RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Phi3Config,
-        device=None,
-    ):
+    def __init__(self, config: Phi3Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
@@ -339,6 +335,9 @@ def __init__(
         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
+        if self.rope_type == "longrope":
+            short_inv_freq, _ = self.rope_init_fn(self.config, device, seq_len=0, **self.rope_kwargs)
+            self.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
 
     def _dynamic_frequency_update(self, position_ids, device):
         """
@@ -360,11 +359,17 @@ def _dynamic_frequency_update(self, position_ids, device):
 
     @torch.no_grad()
     def forward(self, x, position_ids):
+        inv_freq = self.inv_freq
         if "dynamic" in self.rope_type:
             self._dynamic_frequency_update(position_ids, device=x.device)
+            inv_freq = self.inv_freq
+        elif self.rope_type == "longrope":
+            seq_len = torch.max(position_ids) + 1
+            if seq_len <= self.config.original_max_position_embeddings:
+                inv_freq = self.short_inv_freq
 
         # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
         device_type = x.device.type
diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py
index 02f9bbdae5d9fd..0b52e5bc4e9f4d 100644
--- a/src/transformers/models/phi3/modular_phi3.py
+++ b/src/transformers/models/phi3/modular_phi3.py
@@ -213,7 +213,40 @@ def forward(
 
 
 class Phi3RotaryEmbedding(MistralRotaryEmbedding):
-    pass
+    def __init__(self, config: Phi3Config, device=None):
+        super().__init__(config, device)
+        if self.rope_type == "longrope":
+            short_inv_freq, _ = self.rope_init_fn(self.config, device, seq_len=0, **self.rope_kwargs)
+            self.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        inv_freq = self.inv_freq
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids, device=x.device)
+            inv_freq = self.inv_freq
+        elif self.rope_type == "longrope":
+            seq_len = torch.max(position_ids) + 1
+            if seq_len <= self.config.original_max_position_embeddings:
+                inv_freq = self.short_inv_freq
+
+        # Core RoPE block
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 class Phi3PreTrainedModel(MistralPreTrainedModel):
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 36fb1ddf1390ac..c0e0d88431c268 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -283,11 +283,7 @@ def forward(
 
 
 class Qwen2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Qwen2Config,
-        device=None,
-    ):
+    def __init__(self, config: Qwen2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 3b4fdbcb81ccc4..a510ca1e1ca26a 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -274,11 +274,7 @@ def forward(
 
 
 class Starcoder2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Starcoder2Config,
-        device=None,
-    ):
+    def __init__(self, config: Starcoder2Config, device=None):
         super().__init__()
         self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py
index 2c5557dfd67aae..6ec663c6636fa6 100644
--- a/tests/models/phi3/test_modeling_phi3.py
+++ b/tests/models/phi3/test_modeling_phi3.py
@@ -459,6 +459,9 @@ def test_model_rope_scaling_short_long_factor(self, scaling_type):
             "long_factor": [5.0 for _ in range(n_factors)],
         }
         input_tensor = ids_tensor([1, 4090], config.vocab_size)
+        # Make sure we don't have padding tokens. If this is the case, then the actual number of "true" tokens may be shorter
+        # than `config.original_max_position_embeddings + 5`, invalidating this test
+        input_tensor[input_tensor == config.pad_token_id] += 1
         model = Phi3ForCausalLM(config)
         model.to(torch_device)
         model.eval()
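
A minimal standalone sketch (not part of the patch) of what the `modeling_rope_utils.py` change computes: the attention scaling is now always derived from the config ratio `max_position_embeddings / original_max_position_embeddings`, and the long factors are used only when `seq_len` exceeds the pretrained window. The numbers below (head_dim, rope_theta, window sizes) are illustrative Phi-3-mini-128k-style values and the factor lists are made up; none of them come from this diff.

    # Sketch only: mirrors the patched `_compute_longrope_parameters` logic with toy values.
    import math
    import torch

    head_dim, base = 96, 10000.0
    original_max, max_pos = 4096, 131072
    short_factor = [1.0] * (head_dim // 2)  # illustrative per-frequency factors
    long_factor = [4.0] * (head_dim // 2)

    # The cos/sin scaling is derived purely from the config ratio, independent of `seq_len`.
    factor = max_pos / original_max
    attention_scaling = math.sqrt(1 + math.log(factor) / math.log(original_max))  # ~1.19 here

    def longrope_inv_freq(seq_len=None):
        # Long factors only when the target sequence exceeds the pretrained window;
        # `seq_len=0` or `None` therefore always selects the short factors.
        ext = long_factor if (seq_len and seq_len > original_max) else short_factor
        ext_factors = torch.tensor(ext, dtype=torch.float32)
        inv_freq_shape = torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim
        return 1.0 / (ext_factors * base**inv_freq_shape)

    assert torch.equal(longrope_inv_freq(seq_len=0), longrope_inv_freq(seq_len=4096))
    assert not torch.equal(longrope_inv_freq(seq_len=0), longrope_inv_freq(seq_len=4097))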
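The `Phi3RotaryEmbedding` change above pre-registers the short-factor table as a `short_inv_freq` buffer at init time (by calling `rope_init_fn` with `seq_len=0`) and then, on every forward pass, picks short or long frequencies from the largest position id seen. Below is a stripped-down illustration of that pattern; the class name, tensor shapes, and example values are invented for the sketch and are not taken from the patch.

    # Sketch only: a toy rotary embedding that caches both longrope tables and picks
    # one per call, the way the patched Phi3RotaryEmbedding uses `short_inv_freq`.
    import torch
    from torch import nn

    class ToyLongRopeRotary(nn.Module):
        def __init__(self, short_inv_freq, long_inv_freq, original_max, attention_scaling=1.0):
            super().__init__()
            self.original_max = original_max
            self.attention_scaling = attention_scaling
            # Both variants are buffers computed once at init, so nothing is re-initialized
            # when generation crosses the pretrained window.
            self.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
            self.register_buffer("long_inv_freq", long_inv_freq, persistent=False)

        @torch.no_grad()
        def forward(self, x, position_ids):
            # The largest position id decides which table applies, as in the patch.
            seq_len = torch.max(position_ids) + 1
            inv_freq = self.short_inv_freq if seq_len <= self.original_max else self.long_inv_freq
            freqs = position_ids[:, :, None].float() * inv_freq[None, None, :].float()
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
            return cos.to(x.dtype), sin.to(x.dtype)

    rotary = ToyLongRopeRotary(torch.ones(48), torch.ones(48) / 4, original_max=4096)
    cos, sin = rotary(torch.zeros(1, 5, 96), torch.arange(5)[None])  # 5 <= 4096 -> short table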