diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 0b56f687a3c..713cdf06b39 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -89,43 +89,6 @@ def static(cls, config, dim, base, device): if rope_type == "linear": pass - elif rope_type == "longrope": - short_factor = torch.tensor( - rope_scaling["short_factor"], dtype=torch.float32, device=device - ) - long_factor = torch.tensor( - rope_scaling["long_factor"], dtype=torch.float32, device=device - ) - short_mscale = rope_scaling["short_mscale"] - long_mscale = rope_scaling["long_mscale"] - original_max_position_embeddings = ( - config.original_max_position_embeddings - ) - return Phi3LongRoPEScaledRotaryEmbedding( - short_inv_freq=1.0 - / ( - short_factor - * base - ** ( - torch.arange(0, dim, 2, device=device, dtype=torch.float32) - / dim - ) - ), - long_inv_freq=1.0 - / ( - long_factor - * base - ** ( - torch.arange(0, dim, 2, device=device, dtype=torch.float32) - / dim - ) - ), - max_position_embeddings=config.max_position_embeddings, - short_mscale=short_mscale, - long_mscale=long_mscale, - original_max_position_embeddings=original_max_position_embeddings, - ) - elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -203,6 +166,20 @@ def static(cls, config, dim, base, device): 1 + math.log(scale) / math.log(original_max_position_embeddings) ) + # if short_mscale and long_mscale are provided we need to scale the freqs + # using the Phi3LongRoPEScaledRotaryEmbedding + if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling): + short_mscale = rope_scaling["short_mscale"] + long_mscale = rope_scaling["long_mscale"] + return Phi3LongRoPEScaledRotaryEmbedding( + short_inv_freq=short_inv_freq, + long_inv_freq=long_inv_freq, + max_position_embeddings=config.max_position_embeddings, + short_mscale=short_mscale, + long_mscale=long_mscale, + original_max_position_embeddings=original_max_position_embeddings, + ) + return SuRotaryEmbedding( short_inv_freq=short_inv_freq, long_inv_freq=long_inv_freq,