diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py
index 33f7f0b383..484ac2b23a 100644
--- a/tests/models/test_rope_scaling.py
+++ b/tests/models/test_rope_scaling.py
@@ -1,10 +1,5 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
-
-import math
-from typing import Tuple
-
-import torch
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 
 from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
@@ -24,66 +19,6 @@
 rope_dail_config = {}
 
 
-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
-    low_freq_factor = 1
-    high_freq_factor = 4
-    old_context_len = 8192  # original llama3 length
-
-    low_freq_wavelen = old_context_len / low_freq_factor
-    high_freq_wavelen = old_context_len / high_freq_factor
-    new_freqs = []
-    for freq in freqs:
-        wavelen = 2 * math.pi / freq
-        if wavelen < high_freq_wavelen:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / scale_factor)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen -
-                      low_freq_factor) / (high_freq_factor - low_freq_factor)
-            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
-    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
-
-
-def precompute_freqs_cis(
-    dim: int,
-    end: int,
-    theta: float = 10000.0,
-    use_scaled: bool = False,
-):
-    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
-    if use_scaled:
-        freqs = apply_scaling(freqs)
-    freqs = torch.outer(t, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
-
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
 def test_rope_scaling():
     d_model = 128
     n_heads = 32
@@ -98,34 +33,3 @@ def test_rope_scaling():
     )
 
     assert isinstance(embedding, LlamaRotaryEmbedding)
-
-    xq = torch.randn(1, max_seq_len, n_heads, d_model // n_heads)
-    xk = torch.randn(1, max_seq_len, n_heads, d_model // n_heads)
-    position_ids = torch.arange(max_seq_len).unsqueeze(0)
-
-    freqs_cis = precompute_freqs_cis(
-        d_model // n_heads,
-        max_seq_len,
-        rope_config['rope_theta'],
-        use_scaled=True,
-    )
-
-    rope_embeddings_q, rope_embeddings_k = embedding.forward(
-        xq,
-        position_ids,
-    ), embedding.forward(xk, position_ids)
-
-    rope_embeddings_q, rope_embeddings_k = torch.stack(
-        rope_embeddings_q,
-        dim=-1,
-    ), torch.stack(rope_embeddings_k, dim=-1)
-
-    rope_embeddings_q, rope_embeddings_k = rope_embeddings_q.reshape(
-        *rope_embeddings_q.shape[:-2],
-        -1,
-    ), rope_embeddings_k.reshape(*rope_embeddings_k.shape[:-2], -1)
-
-    expected_q, expected_k = apply_rotary_emb(xq, xk, freqs_cis)
-
-    torch.testing.assert_close(rope_embeddings_q, expected_q)
-    torch.testing.assert_close(rope_embeddings_k, expected_k)