
Support rope scaling #1391

Merged · Jul 24, 2024 · 24 commits (diff shows changes from 19 commits)
1 change: 1 addition & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -306,6 +306,7 @@ def _validate_config(self) -> None:
        'no_scaling',
        'linear',
        'dynamic',
+        'llama3',
    ]:
        raise ValueError(
            'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".',
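For context, a minimal sketch of an MPT `attn_config` rope block that exercises the newly allowed `llama3` type. The field names mirror the `rope_hf_config` used in the new `tests/models/test_rope_scaling.py` below; the values are illustrative, not recommendations.

```python
# Illustrative only: an MPT attn_config fragment using the new 'llama3' rope type.
attn_config = {
    'rope': True,
    'rope_impl': 'hf',      # use the Hugging Face rotary implementation
    'rope_theta': 500000.0,
    'rope_hf_config': {
        'type': 'llama3',   # newly accepted alongside 'no_scaling', 'linear', 'dynamic'
        'factor': 8.0,
        'low_freq_factor': 1.0,
        'high_freq_factor': 4.0,
        'original_max_position_embeddings': 8192,
    },
}
```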
97 changes: 66 additions & 31 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -49,10 +49,8 @@
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
-from transformers.models.llama.modeling_llama import \
-    LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
-from transformers.models.llama.modeling_llama import \
-    LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
+from transformers.models.llama.modeling_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import \
    LlamaRotaryEmbedding as HFRotaryEmbedding

@@ -88,14 +86,62 @@
log = logging.getLogger(__name__)


+class InvalidConfigAccessError(KeyError):
+    pass
+
+
+_ALLOWED_LLAMA_CONFIG_KEYS = {
+    # These are the only config keys that are set and are safe to read from
+    'rope_scaling',
+    'rope_theta',
+    'max_position_embeddings',
+    'hidden_size',
+    'num_attention_heads',
+
+    # Not set but llama modeling code tries to read this attribute
+    'partial_rotary_factor',
+
+    # Benign transformers attributes needed for __init__
+    '_get_generation_defaults',
+    'label2id',
+    'id2label',
+    'torch_dtype',
+    'problem_type',
+    '__class__',
+}
+
+
+class PartialLlamaConfig(LlamaConfig):
"""Holds the rope config for Llama models and throws.

an `InvalidConfigAccessError` if any other config elements are read. This
class is necessary because the `LlamaRotaryEmbedding` class takes a full
`LlamaConfig` now instead of the old keyword arguments.
"""

def __getattribute__(self, key: str):
if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
raise InvalidConfigAccessError(key)

return super().__getattribute__(key)

def __getitem__(self, key: str):
if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
raise InvalidConfigAccessError(key)

return super().__getitem__(key)


def gen_rotary_embedding(
-    rope_head_dim: int,
    rope_impl: str,
    rope_theta: int,
    rope_dail_config: dict,
    rope_hf_config: dict,
    max_seq_len: int,
+    d_model: int,
+    n_heads: int,
):
+    rope_head_dim = d_model // n_heads
if rope_impl == 'dail':
return DAILRotaryEmbedding(
dim=rope_head_dim,
@@ -108,32 +154,20 @@ def gen_rotary_embedding(
'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
elif rope_impl == 'hf':
+        llama_rope_config = {**rope_hf_config}
+        llama_rope_config['rope_type'] = rope_hf_config.get('type')
+        partial_llama_config = PartialLlamaConfig(
+            rope_scaling=llama_rope_config,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_seq_len,
+            hidden_size=d_model,
+            num_attention_heads=n_heads,
+        )
        if rope_hf_config['type'] == 'no_scaling':
-            return HFRotaryEmbeddingFoundry(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=rope_theta,
-                device=
-                'cpu',  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
-        elif rope_hf_config['type'] == 'linear':
-            return HFLinearScalingRotaryEmbedding(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=rope_theta,
-                scaling_factor=rope_hf_config['factor'],
-                device=
-                'cpu',  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
-        elif rope_hf_config['type'] == 'dynamic':
-            return HFDynamicNTKScalingRotaryEmbedding(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=rope_theta,
-                scaling_factor=rope_hf_config['factor'],
-                device=
-                'cpu',  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
+            llama_rope_config['rope_type'] = 'default'
+            return HFRotaryEmbeddingFoundry(config=partial_llama_config)
+        elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}:
+            return LlamaRotaryEmbedding(config=partial_llama_config)
raise ValueError('rope_impl needs to be either dail or hf')


@@ -399,12 +433,13 @@ def __init__(self, config: MPTConfig):
if self.rope:
self.rope_impl = config.attn_config['rope_impl']
self.rotary_embedding = gen_rotary_embedding(
-                rope_head_dim=config.d_model // config.n_heads,
                rope_impl=self.rope_impl,
                rope_theta=config.attn_config['rope_theta'],
                rope_dail_config=config.attn_config['rope_dail_config'],
                rope_hf_config=config.attn_config['rope_hf_config'],
                max_seq_len=self.config.max_seq_len,
+                d_model=config.d_model,
+                n_heads=config.n_heads,
)

if config.init_device != 'meta':
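To make the intent of `PartialLlamaConfig` above concrete, here is a small sketch (not part of the diff) of how the guard behaves: the whitelisted rope fields read normally, while any other config attribute access fails loudly instead of silently falling back to a Llama default. The constructor arguments mirror the call in `gen_rotary_embedding`; exact behavior depends on the transformers version pinned at the time.

```python
# Illustrative sketch, not part of the PR.
from llmfoundry.models.mpt.modeling_mpt import (
    InvalidConfigAccessError,
    PartialLlamaConfig,
)

partial_llama_config = PartialLlamaConfig(
    rope_scaling={
        'rope_type': 'llama3',
        'factor': 8.0,
        'low_freq_factor': 1.0,
        'high_freq_factor': 4.0,
        'original_max_position_embeddings': 8192,
    },
    rope_theta=500000.0,
    max_position_embeddings=65536,
    hidden_size=128,
    num_attention_heads=32,
)

assert partial_llama_config.rope_theta == 500000.0  # whitelisted key: reads fine

try:
    _ = partial_llama_config.vocab_size  # not whitelisted: raises
except InvalidConfigAccessError:
    pass
```

The point of the allow-list is to fail fast if the upstream Llama rotary-embedding code ever starts reading config fields that MPT never populates, rather than quietly using Llama defaults.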
6 changes: 4 additions & 2 deletions tests/models/layers/test_flash_torch.py
@@ -251,12 +251,13 @@ def gen_bias(attn_impl: str):
rotary_emb_w_meta_info = None
if rope:
rotary_embedding = gen_rotary_embedding(
-            rope_head_dim=cfg.d_model // cfg.n_heads,
            rope_impl=pos_emb_config['rope_impl'],
            rope_theta=pos_emb_config['rope_theta'],
            rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
            rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
            max_seq_len=s,
+            d_model=cfg.d_model,
+            n_heads=cfg.n_heads,
).to(device)
pos = torch.arange(s).unsqueeze(0).to(device=device)
# adjust the position indices to account for padding tokens
@@ -664,12 +665,13 @@ def gen_bias(attn_impl: str):
rotary_emb_w_meta_info = None
if rope:
rotary_embedding = gen_rotary_embedding(
-            rope_head_dim=cfg['d_model'] // cfg['n_heads'],
            rope_impl=pos_emb_config['rope_impl'],
            rope_theta=pos_emb_config['rope_theta'],
            rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
            rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
            max_seq_len=s,
+            d_model=cfg['d_model'],
+            n_heads=cfg['n_heads'],
).to(device)
pos = torch.arange(s).unsqueeze(0).to(device=device)
# adjust the position indices to account for padding tokens
6 changes: 4 additions & 2 deletions tests/models/test_rope_dail_vs_hf.py
@@ -77,12 +77,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
}

dail_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
        rope_impl=dail_rope_config['rope_impl'],
        rope_theta=dail_rope_config['rope_theta'],
        rope_dail_config=dail_rope_config['rope_dail_config'],
        rope_hf_config={},
        max_seq_len=seq_len,
+        d_model=cfg.d_model,
+        n_heads=cfg.n_heads,
).to('cuda')
dail_rope_w_meta_info = {
'impl': 'dail',
@@ -92,12 +93,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
}

hf_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
        rope_impl=hf_rope_config['rope_impl'],
        rope_theta=hf_rope_config['rope_theta'],
        rope_dail_config={},
        rope_hf_config=hf_rope_config['rope_hf_config'],
        max_seq_len=seq_len,
+        d_model=cfg.d_model,
+        n_heads=cfg.n_heads,
).to('cuda')
pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda')
# adjust the position indices to account for padding tokens
36 changes: 36 additions & 0 deletions tests/models/test_rope_scaling.py
@@ -0,0 +1,36 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

rope_config = {
'rope_theta': 500000.0,
'rope_impl': 'hf',
'rope_hf_config': {
'factor': 8.0,
'low_freq_factor': 1.0,
'high_freq_factor': 4.0,
'original_max_position_embeddings': 8192,
'type': 'llama3',
},
}

rope_dail_config = {}


def test_rope_scaling():
d_model = 128
n_heads = 32
max_seq_len = 65536

embedding = gen_rotary_embedding(
d_model=d_model,
n_heads=n_heads,
rope_dail_config=rope_dail_config,
max_seq_len=max_seq_len,
**rope_config,
)

assert isinstance(embedding, LlamaRotaryEmbedding)
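As a usage note (not part of the test): in the transformers releases current at the time of this PR, `LlamaRotaryEmbedding` is called with a tensor (only its dtype and device are used) and integer position ids, and returns the cos/sin tables with the llama3 scaling already folded into `inv_freq`. A hedged sketch of how the embedding returned above could be exercised further, assuming that forward signature:

```python
import torch

# Hypothetical extension inside test_rope_scaling; shapes assume head_dim = d_model // n_heads.
seq_len = 16
pos = torch.arange(seq_len).unsqueeze(0)      # (1, seq_len) position ids
dummy = torch.zeros(1, seq_len, d_model)      # only dtype/device are read from this
cos, sin = embedding(dummy, pos)
assert cos.shape == (1, seq_len, d_model // n_heads)
assert sin.shape == cos.shape
```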