Add dtype and change beta type to int from float
seungduk-yanolja authored May 27, 2024
1 parent 688606f commit eadc7f8
Showing 1 changed file with 5 additions and 4 deletions.
vllm/model_executor/layers/rotary_embedding.py: 9 changes (5 additions, 4 deletions)
@@ -523,11 +523,12 @@ def __init__(
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
-        beta_fast: float = 32,
-        beta_slow: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
         mscale: float = 1,
         mscale_all_dim: float = 0,
     ) -> None:
@@ -542,7 +543,7 @@ def __init__(
             yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
             attn_factor)
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base**(torch.arange(
@@ -684,7 +685,7 @@ def get_rope(
             }
             rotary_emb = DeepseekScalingRotaryEmbedding(
                 head_size, rotary_dim, original_max_position, base,
-                is_neox_style, scaling_factor, **extra_kwargs)
+                is_neox_style, scaling_factor, dtype, **extra_kwargs)
         elif scaling_type == "su":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]