Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support rope scaling #1391

Merged
merged 24 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def _validate_config(self) -> None:
'no_scaling',
'linear',
'dynamic',
'llama3',
]:
raise ValueError(
'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".',
Expand Down
64 changes: 62 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaConfig
from transformers.models.llama.modeling_llama import \
LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
from transformers.models.llama.modeling_llama import \
LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

Expand Down Expand Up @@ -88,14 +90,59 @@
log = logging.getLogger(__name__)


class InvalidConfigAccessError(KeyError):
pass


_ALLOWED_LLAMA_CONFIG_KEYS = {
# these are the only config keys that are set and are safe to read from
milocress marked this conversation as resolved.
Show resolved Hide resolved
'rope_scaling',
'rope_theta',
'max_position_embeddings',
'hidden_size',
'num_attention_heads',

# not set but llama modeling code tries to read this attribute
milocress marked this conversation as resolved.
Show resolved Hide resolved
'partial_rotary_factor',

# benign transformers attributes needed for __init__
milocress marked this conversation as resolved.
Show resolved Hide resolved
'_get_generation_defaults',
'label2id',
'id2label',
'torch_dtype',
'problem_type',
'__class__',
}


class PartialLlamaConfig(LlamaConfig):
milocress marked this conversation as resolved.
Show resolved Hide resolved

def __getattribute__(self, key: str):
if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
raise InvalidConfigAccessError(key)

return super().__getattribute__(key)

def __getitem__(self, key: str):
if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
raise InvalidConfigAccessError(key)

return super().__getitem__(key)

def _get_generation_defaults(self):
milocress marked this conversation as resolved.
Show resolved Hide resolved
return {}


def gen_rotary_embedding(
rope_head_dim: int,
rope_impl: str,
rope_theta: int,
rope_dail_config: dict,
rope_hf_config: dict,
max_seq_len: int,
d_model: int,
n_heads: int,
):
rope_head_dim = d_model // n_heads
if rope_impl == 'dail':
return DAILRotaryEmbedding(
dim=rope_head_dim,
Expand Down Expand Up @@ -134,6 +181,18 @@ def gen_rotary_embedding(
device=
'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu
milocress marked this conversation as resolved.
Show resolved Hide resolved
)
elif rope_hf_config['type'] == 'llama3':
llama_rope_config = {**rope_hf_config}
llama_rope_config['rope_type'] = rope_hf_config.get('type')
return LlamaRotaryEmbedding(
config=PartialLlamaConfig(
milocress marked this conversation as resolved.
Show resolved Hide resolved
rope_scaling=llama_rope_config,
rope_theta=rope_theta,
max_position_embeddings=max_seq_len,
hidden_size=d_model,
num_attention_heads=n_heads,
),
)
raise ValueError('rope_impl needs to be either dail or hf')


Expand Down Expand Up @@ -399,12 +458,13 @@ def __init__(self, config: MPTConfig):
if self.rope:
self.rope_impl = config.attn_config['rope_impl']
self.rotary_embedding = gen_rotary_embedding(
rope_head_dim=config.d_model // config.n_heads,
rope_impl=self.rope_impl,
rope_theta=config.attn_config['rope_theta'],
rope_dail_config=config.attn_config['rope_dail_config'],
rope_hf_config=config.attn_config['rope_hf_config'],
max_seq_len=self.config.max_seq_len,
d_model=config.d_model,
n_heads=config.n_heads,
)

if config.init_device != 'meta':
Expand Down
6 changes: 4 additions & 2 deletions tests/models/layers/test_flash_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ def gen_bias(attn_impl: str):
rotary_emb_w_meta_info = None
if rope:
rotary_embedding = gen_rotary_embedding(
rope_head_dim=cfg.d_model // cfg.n_heads,
rope_impl=pos_emb_config['rope_impl'],
rope_theta=pos_emb_config['rope_theta'],
rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
max_seq_len=s,
d_model=cfg.d_model,
n_heads=cfg.n_heads,
).to(device)
pos = torch.arange(s).unsqueeze(0).to(device=device)
# adjust the position indices to account for padding tokens
Expand Down Expand Up @@ -664,12 +665,13 @@ def gen_bias(attn_impl: str):
rotary_emb_w_meta_info = None
if rope:
rotary_embedding = gen_rotary_embedding(
rope_head_dim=cfg['d_model'] // cfg['n_heads'],
rope_impl=pos_emb_config['rope_impl'],
rope_theta=pos_emb_config['rope_theta'],
rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
max_seq_len=s,
d_model=cfg['d_model'],
n_heads=cfg['n_heads'],
).to(device)
pos = torch.arange(s).unsqueeze(0).to(device=device)
# adjust the position indices to account for padding tokens
Expand Down
6 changes: 4 additions & 2 deletions tests/models/test_rope_dail_vs_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
}

dail_rope = gen_rotary_embedding(
rope_head_dim=cfg.d_model // cfg.n_heads,
rope_impl=dail_rope_config['rope_impl'],
rope_theta=dail_rope_config['rope_theta'],
rope_dail_config=dail_rope_config['rope_dail_config'],
rope_hf_config={},
max_seq_len=seq_len,
d_model=cfg.d_model,
n_heads=cfg.n_heads,
).to('cuda')
dail_rope_w_meta_info = {
'impl': 'dail',
Expand All @@ -92,12 +93,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
}

hf_rope = gen_rotary_embedding(
rope_head_dim=cfg.d_model // cfg.n_heads,
rope_impl=hf_rope_config['rope_impl'],
rope_theta=hf_rope_config['rope_theta'],
rope_dail_config={},
rope_hf_config=hf_rope_config['rope_hf_config'],
max_seq_len=seq_len,
d_model=cfg.d_model,
n_heads=cfg.n_heads,
).to('cuda')
pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda')
# adjust the position indices to account for padding tokens
Expand Down
36 changes: 36 additions & 0 deletions tests/models/test_rope_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
milocress marked this conversation as resolved.
Show resolved Hide resolved

from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

rope_config = {
'rope_theta': 500000.0,
'rope_impl': 'hf',
'rope_hf_config': {
'factor': 8.0,
'low_freq_factor': 1.0,
'high_freq_factor': 4.0,
'original_max_position_embeddings': 8192,
'type': 'llama3',
},
}

rope_dail_config = {}
dakinggg marked this conversation as resolved.
Show resolved Hide resolved


def test_rope_scaling():
d_model = 128
n_heads = 32
max_seq_len = 131_000

embedding = gen_rotary_embedding(
d_model=d_model,
n_heads=n_heads,
rope_dail_config=rope_dail_config,
max_seq_len=max_seq_len,
**rope_config,
)

assert isinstance(embedding, LlamaRotaryEmbedding)
Loading