Commit

give up on testing hf
milocress committed Jul 24, 2024
1 parent cf460ec · commit d48bcb4
Showing 1 changed file with 0 additions and 96 deletions.
tests/models/test_rope_scaling.py
@@ -1,10 +1,5 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import math
from typing import Tuple

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
@@ -24,66 +19,6 @@
rope_dail_config = {}


def apply_scaling(freqs: torch.Tensor):
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen -
                      low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
):
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    if use_scaled:
        freqs = apply_scaling(freqs)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def test_rope_scaling():
    d_model = 128
    n_heads = 32
@@ -98,34 +33,3 @@ def test_rope_scaling():
    )

    assert isinstance(embedding, LlamaRotaryEmbedding)

    xq = torch.randn(1, max_seq_len, n_heads, d_model // n_heads)
    xk = torch.randn(1, max_seq_len, n_heads, d_model // n_heads)
    position_ids = torch.arange(max_seq_len).unsqueeze(0)

    freqs_cis = precompute_freqs_cis(
        d_model // n_heads,
        max_seq_len,
        rope_config['rope_theta'],
        use_scaled=True,
    )

    rope_embeddings_q, rope_embeddings_k = embedding.forward(
        xq,
        position_ids,
    ), embedding.forward(xk, position_ids)

    rope_embeddings_q, rope_embeddings_k = torch.stack(
        rope_embeddings_q,
        dim=-1,
    ), torch.stack(rope_embeddings_k, dim=-1)

    rope_embeddings_q, rope_embeddings_k = rope_embeddings_q.reshape(
        *rope_embeddings_q.shape[:-2],
        -1,
    ), rope_embeddings_k.reshape(*rope_embeddings_k.shape[:-2], -1)

    expected_q, expected_k = apply_rotary_emb(xq, xk, freqs_cis)

    torch.testing.assert_close(rope_embeddings_q, expected_q)
    torch.testing.assert_close(rope_embeddings_k, expected_k)
