fix: consolidate long rope paths
drbh committed Sep 6, 2024
1 parent 13a69a8 commit 4b8856d
51 changes: 14 additions & 37 deletions in server/text_generation_server/layers/rotary.py
@@ -89,43 +89,6 @@ def static(cls, config, dim, base, device):
 
             if rope_type == "linear":
                 pass
-            elif rope_type == "longrope":
-                short_factor = torch.tensor(
-                    rope_scaling["short_factor"], dtype=torch.float32, device=device
-                )
-                long_factor = torch.tensor(
-                    rope_scaling["long_factor"], dtype=torch.float32, device=device
-                )
-                short_mscale = rope_scaling["short_mscale"]
-                long_mscale = rope_scaling["long_mscale"]
-                original_max_position_embeddings = (
-                    config.original_max_position_embeddings
-                )
-                return Phi3LongRoPEScaledRotaryEmbedding(
-                    short_inv_freq=1.0
-                    / (
-                        short_factor
-                        * base
-                        ** (
-                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
-                            / dim
-                        )
-                    ),
-                    long_inv_freq=1.0
-                    / (
-                        long_factor
-                        * base
-                        ** (
-                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
-                            / dim
-                        )
-                    ),
-                    max_position_embeddings=config.max_position_embeddings,
-                    short_mscale=short_mscale,
-                    long_mscale=long_mscale,
-                    original_max_position_embeddings=original_max_position_embeddings,
-                )
-
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
@@ -203,6 +166,20 @@ def static(cls, config, dim, base, device):
                         1 + math.log(scale) / math.log(original_max_position_embeddings)
                     )
 
+                # if short_mscale and long_mscale are provided we need to scale the freqs
+                # using the Phi3LongRoPEScaledRotaryEmbedding
+                if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
+                    short_mscale = rope_scaling["short_mscale"]
+                    long_mscale = rope_scaling["long_mscale"]
+                    return Phi3LongRoPEScaledRotaryEmbedding(
+                        short_inv_freq=short_inv_freq,
+                        long_inv_freq=long_inv_freq,
+                        max_position_embeddings=config.max_position_embeddings,
+                        short_mscale=short_mscale,
+                        long_mscale=long_mscale,
+                        original_max_position_embeddings=original_max_position_embeddings,
+                    )
+
                 return SuRotaryEmbedding(
                     short_inv_freq=short_inv_freq,
                     long_inv_freq=long_inv_freq,
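
For context, the consolidated path can be read as the following minimal sketch: the short/long inverse frequencies are computed once, then configs that carry explicit short_mscale/long_mscale values (Phi-3-style longrope) take the Phi3LongRoPEScaledRotaryEmbedding branch, while everything else falls through toward SuRotaryEmbedding. The plan_rotary helper and its (branch, kwargs) return convention are hypothetical illustration only; the key names and formulas come from the diff above, and actual class construction is elided.

import math

import torch


def plan_rotary(rope_scaling, dim, base, max_pos, original_max_pos, device="cpu"):
    # Hypothetical helper mirroring the consolidated branch in rotary.py;
    # it returns (branch_name, kwargs) instead of building the TGI classes.

    def scaled_inv_freq(key):
        factor = torch.tensor(rope_scaling[key], dtype=torch.float32, device=device)
        # Standard RoPE inverse frequencies, divided elementwise by the
        # per-dimension scaling factors from the model config.
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        return 1.0 / (factor * base**exponents)

    short_inv_freq = scaled_inv_freq("short_factor")
    long_inv_freq = scaled_inv_freq("long_factor")

    # Explicit mscales mark the Phi3LongRoPEScaledRotaryEmbedding branch.
    if "short_mscale" in rope_scaling and "long_mscale" in rope_scaling:
        return "phi3_longrope", {
            "short_inv_freq": short_inv_freq,
            "long_inv_freq": long_inv_freq,
            "max_position_embeddings": max_pos,
            "short_mscale": rope_scaling["short_mscale"],
            "long_mscale": rope_scaling["long_mscale"],
            "original_max_position_embeddings": original_max_pos,
        }

    # Otherwise derive the attention scaling factor from the hunk's context
    # (the enclosing math.sqrt is assumed from the surrounding file, not
    # visible in the diff) and fall through to the SuRotaryEmbedding branch;
    # its remaining constructor arguments are elided here.
    if max_pos <= original_max_pos:
        scaling_factor = 1.0
    else:
        scale = max_pos / original_max_pos
        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_pos))
    return "su", {
        "short_inv_freq": short_inv_freq,
        "long_inv_freq": long_inv_freq,
        "scaling_factor": scaling_factor,
    }


# Toy Phi-3-style config: dim // 2 = 48 per-dimension factors for dim = 96.
cfg = {
    "short_factor": [1.0] * 48,
    "long_factor": [4.0] * 48,
    "short_mscale": 1.0,
    "long_mscale": 1.1,
}
branch, kwargs = plan_rotary(
    cfg, dim=96, base=10000.0, max_pos=131072, original_max_pos=4096
)
print(branch)  # phi3_longrope; drop the two mscales and it falls back to "su"

The point of the consolidation is visible in the sketch: the deleted longrope branch duplicated exactly this inverse-frequency arithmetic, and after the commit the only difference between the two paths is which scaling parameters travel with the frequencies.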
