Commit

add unit test
milocress committed Jul 24, 2024
1 parent 77fd401 commit 0682283
Showing 4 changed files with 61 additions and 16 deletions.
37 changes: 25 additions & 12 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -94,30 +94,41 @@ class InvalidConfigAccessError(KeyError):
    pass


+_ALLOWED_LLAMA_CONFIG_KEYS = {
+    'rope_scaling',
+    'rope_theta',
+    'max_position_embeddings',
+    'hidden_size',
+    'num_attention_heads',
+    '_get_generation_defaults',
+    'label2id',
+    'id2label',
+    'torch_dtype',
+    'problem_type',
+    '__class__',
+    'partial_rotary_factor',
+}
+
+
class PartialLlamaConfig(LlamaConfig):
-    _ALLOWED_KEYS = {
-        'rope_scaling',
-        'rope_theta',
-        'max_position_embeddings',
-        'hidden_size',
-        'num_attention_heads',
-    }

    def __getattribute__(self, key: str):
-        if key not in self._ALLOWED_KEYS:
+        if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
            raise InvalidConfigAccessError(key)

        return super().__getattribute__(key)

    def __getitem__(self, key: str):
-        if key not in self._ALLOWED_KEYS:
+        if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
            raise InvalidConfigAccessError(key)

        return super().__getitem__(key)
+
+    def _get_generation_defaults(self):
+        return {}


def gen_rotary_embedding(
-    rope_head_dim: int,
    rope_impl: str,
    rope_theta: int,
    rope_dail_config: dict,
@@ -126,6 +137,7 @@ def gen_rotary_embedding(
    d_model: int,
    n_heads: int,
):
+    rope_head_dim = d_model // n_heads
    if rope_impl == 'dail':
        return DAILRotaryEmbedding(
            dim=rope_head_dim,
@@ -165,9 +177,11 @@ def gen_rotary_embedding(
                'cpu',  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
            )
        elif rope_hf_config['type'] == 'llama3':
+            llama_rope_config = {**rope_hf_config}
+            llama_rope_config['rope_type'] = rope_hf_config.pop('type')
            return LlamaRotaryEmbedding(
                config=PartialLlamaConfig(
-                    rope_scaling=rope_hf_config,
+                    rope_scaling=llama_rope_config,
                    rope_theta=rope_theta,
                    max_position_embeddings=max_seq_len,
                    hidden_size=d_model,
@@ -439,7 +453,6 @@ def __init__(self, config: MPTConfig):
        if self.rope:
            self.rope_impl = config.attn_config['rope_impl']
            self.rotary_embedding = gen_rotary_embedding(
-                rope_head_dim=config.d_model // config.n_heads,
                rope_impl=self.rope_impl,
                rope_theta=config.attn_config['rope_theta'],
                rope_dail_config=config.attn_config['rope_dail_config'],
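For context (not part of the commit): the module-level whitelist above is what lets PartialLlamaConfig be handed to LlamaRotaryEmbedding without exposing the rest of LlamaConfig. A minimal sketch of that behavior, assuming llmfoundry and a recent transformers are installed and that both names are importable from the file changed above:

from llmfoundry.models.mpt.modeling_mpt import (
    InvalidConfigAccessError,
    PartialLlamaConfig,
)

# Same keys the llama3 branch of gen_rotary_embedding passes in.
config = PartialLlamaConfig(
    rope_scaling={
        'rope_type': 'llama3',
        'factor': 8.0,
        'low_freq_factor': 1.0,
        'high_freq_factor': 4.0,
        'original_max_position_embeddings': 8192,
    },
    rope_theta=500000.0,
    max_position_embeddings=131_000,
    hidden_size=128,
    num_attention_heads=32,
)

# Whitelisted keys resolve through the normal LlamaConfig lookup.
assert config.rope_theta == 500000.0

# Anything else fails loudly instead of silently returning a LlamaConfig default.
try:
    _ = config.vocab_size
except InvalidConfigAccessError as key:
    print(f'blocked config access: {key}')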
2 changes: 0 additions & 2 deletions tests/models/layers/test_flash_torch.py
@@ -251,7 +251,6 @@ def gen_bias(attn_impl: str):
    rotary_emb_w_meta_info = None
    if rope:
        rotary_embedding = gen_rotary_embedding(
-            rope_head_dim=cfg.d_model // cfg.n_heads,
            rope_impl=pos_emb_config['rope_impl'],
            rope_theta=pos_emb_config['rope_theta'],
            rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
@@ -666,7 +665,6 @@ def gen_bias(attn_impl: str):
    rotary_emb_w_meta_info = None
    if rope:
        rotary_embedding = gen_rotary_embedding(
-            rope_head_dim=cfg['d_model'] // cfg['n_heads'],
            rope_impl=pos_emb_config['rope_impl'],
            rope_theta=pos_emb_config['rope_theta'],
            rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
2 changes: 0 additions & 2 deletions tests/models/test_rope_dail_vs_hf.py
@@ -77,7 +77,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
    }

    dail_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
        rope_impl=dail_rope_config['rope_impl'],
        rope_theta=dail_rope_config['rope_theta'],
        rope_dail_config=dail_rope_config['rope_dail_config'],
@@ -94,7 +93,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
    }

    hf_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
        rope_impl=hf_rope_config['rope_impl'],
        rope_theta=hf_rope_config['rope_theta'],
        rope_dail_config={},
36 changes: 36 additions & 0 deletions tests/models/test_rope_scaling.py
@@ -0,0 +1,36 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

rope_config = {
    'rope_theta': 500000.0,
    'rope_impl': 'hf',
    'rope_hf_config': {
        'factor': 8.0,
        'low_freq_factor': 1.0,
        'high_freq_factor': 4.0,
        'original_max_position_embeddings': 8192,
        'type': 'llama3',
    },
}

rope_dail_config = {}


def test_rope_scaling():
    d_model = 128
    n_heads = 32
    max_seq_len = 131_000

    embedding = gen_rotary_embedding(
        d_model=d_model,
        n_heads=n_heads,
        rope_dail_config=rope_dail_config,
        max_seq_len=max_seq_len,
        **rope_config,
    )

    assert isinstance(embedding, LlamaRotaryEmbedding)

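A hedged companion sketch (hypothetical, not part of the commit) that could sit after the test above, reusing its imports: since gen_rotary_embedding now derives the head dimension internally as d_model // n_heads, the returned module should hold head_dim // 2 inverse frequencies in the stock transformers 'inv_freq' buffer. It builds a fresh config dict because the llama3 branch pops 'type' from the rope_hf_config it receives.

def test_rope_scaling_derives_head_dim():
    # Hypothetical follow-up test, not part of this commit.
    d_model = 128
    n_heads = 32

    embedding = gen_rotary_embedding(
        d_model=d_model,
        n_heads=n_heads,
        rope_dail_config={},
        max_seq_len=131_000,
        rope_theta=500000.0,
        rope_impl='hf',
        # Fresh dict: the llama3 branch pops 'type' from the dict it is given,
        # so reusing the module-level rope_config would mutate shared state.
        rope_hf_config={
            'factor': 8.0,
            'low_freq_factor': 1.0,
            'high_freq_factor': 4.0,
            'original_max_position_embeddings': 8192,
            'type': 'llama3',
        },
    )

    # rope_head_dim = d_model // n_heads, and LlamaRotaryEmbedding keeps
    # head_dim // 2 inverse frequencies in its 'inv_freq' buffer.
    assert embedding.inv_freq.shape[-1] == (d_model // n_heads) // 2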