
Commit

ShashankMosaicML committed Oct 17, 2023
1 parent dc58fc7 commit ed5a477
Showing 2 changed files with 18 additions and 10 deletions.
23 changes: 17 additions & 6 deletions llmfoundry/models/layers/rotary_embedding.py
@@ -7,8 +7,11 @@
 # Code modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 class RotaryEmbedding(torch.nn.Module):
 
-    def __init__(self, dim: int, max_position_embeddings: int, base: int,
-                 device: torch.device):
+    def __init__(self,
+                 dim: int,
+                 max_position_embeddings: int,
+                 base: int,
+                 device=None):
         super().__init__()
 
         self.dim = dim
@@ -60,8 +63,12 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
     Credits to the Reddit user /u/kaiokendev
     """
 
-    def __init__(self, dim: int, max_position_embeddings: int, base: int,
-                 device: torch.device, scaling_factor: float):
+    def __init__(self,
+                 dim: int,
+                 max_position_embeddings: int,
+                 base: int,
+                 scaling_factor: float,
+                 device=None):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

@@ -90,8 +97,12 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
     Credits to the Reddit users /u/bloc97 and /u/emozilla
     """
 
-    def __init__(self, dim: int, max_position_embeddings: int, base: int,
-                 device: torch.device, scaling_factor: float):
+    def __init__(self,
+                 dim: int,
+                 max_position_embeddings: int,
+                 base: int,
+                 scaling_factor: float,
+                 device=None):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

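For context, here is a minimal sketch of the pattern this signature change adopts, modeled on the Hugging Face Llama rotary embedding that the file credits. The class name RotaryEmbeddingSketch and the forward method shown are illustrative assumptions, not this repo's actual implementation: with device defaulting to None, the inverse-frequency buffer is created on the default device and moves with the module, so callers no longer need to pass a device at construction time.

import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    """Illustrative only: `device` is optional; the inv_freq buffer is
    created on the default device and relocates with the module via .to()."""

    def __init__(self,
                 dim: int,
                 max_position_embeddings: int,
                 base: int = 10000,
                 device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, seq_len: int):
        # Build the cos/sin tables on whatever device the buffer lives on now.
        t = torch.arange(seq_len,
                         device=self.inv_freq.device,
                         dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


# Usage: no device argument needed at construction time.
rope = RotaryEmbeddingSketch(dim=64, max_position_embeddings=2048)
cos, sin = rope(seq_len=16)  # cos and sin each have shape (16, 64)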
5 changes: 1 addition & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -78,21 +78,18 @@ def _rotary_embedding(config: MPTConfig):
     if config.attn_config['rope_scaling']['type'] == 'no_scaling':
         return RotaryEmbedding(rope_head_dim,
                                max_position_embeddings=config.max_seq_len,
-                               base=config.attn_config['rope_theta'],
-                               device=config.init_device)
+                               base=config.attn_config['rope_theta'])
     elif config.attn_config['rope_scaling']['type'] == 'linear':
         return LinearScalingRotaryEmbedding(
             rope_head_dim,
             max_position_embeddings=config.max_seq_len,
             base=config.attn_config['rope_theta'],
-            device=config.init_device,
             scaling_factor=config.attn_config['rope_scaling']['factor'])
     elif config.attn_config['rope_scaling']['type'] == 'dynamic':
         return DynamicNTKScalingRotaryEmbedding(
             rope_head_dim,
             max_position_embeddings=config.max_seq_len,
             base=config.attn_config['rope_theta'],
-            device=config.init_device,
             scaling_factor=config.attn_config['rope_scaling']['factor'])
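For reference, here is a hypothetical attn_config fragment for each branch that _rotary_embedding dispatches on above. The key names ('rope_theta', 'rope_scaling', 'type', 'factor') come from the diff; the values are made up for illustration and are not this repo's defaults.

# Illustrative config fragments only; one per rope_scaling branch above.
no_scaling_cfg = {
    'rope_theta': 10000,
    'rope_scaling': {'type': 'no_scaling'},
}
linear_cfg = {
    'rope_theta': 10000,
    'rope_scaling': {'type': 'linear', 'factor': 2.0},
}
dynamic_cfg = {
    'rope_theta': 10000,
    'rope_scaling': {'type': 'dynamic', 'factor': 2.0},
}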


