diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 0d52ee214d..059011a847 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -86,9 +86,9 @@
 log = logging.getLogger(__name__)
 
 
-def _rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
-                      rope_dail_config: dict, rope_hf_config: dict,
-                      max_seq_len: int):
+def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
+                         rope_dail_config: dict, rope_hf_config: dict,
+                         max_seq_len: int):
     if rope_impl == 'dail':
         return DAILRotaryEmbedding(
             dim=rope_head_dim,
@@ -127,6 +127,7 @@ def _rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
                 device=
                 'cpu'  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
             )
+    raise ValueError('rope_impl needs to be either dail or hf')
 
 
 class MPTPreTrainedModel(PreTrainedModel):
@@ -186,7 +187,7 @@ def __init__(self, config: MPTConfig):
         self.rope_impl = None
         if self.rope:
             self.rope_impl = config.attn_config['rope_impl']
-            self.rotary_embedding = _rotary_embedding(
+            self.rotary_embedding = gen_rotary_embedding(
                 rope_head_dim=config.d_model // config.n_heads,
                 rope_impl=self.rope_impl,
                 rope_theta=config.attn_config['rope_theta'],
diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py
index cae9490cc3..3f79bf0c7e 100644
--- a/tests/test_flash_triton_torch.py
+++ b/tests/test_flash_triton_torch.py
@@ -6,7 +6,7 @@
 from omegaconf import OmegaConf as om
 
 from llmfoundry.models.layers.attention import is_flash_v2_installed
-from tests.test_rope_dail_vs_hf import gen_rotary_embedding
+from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
 
 
 def allclose_helper(t0: torch.Tensor,
@@ -130,7 +130,10 @@ def gen_bias(attn_impl: str):
     if rope:
         rotary_embedding = gen_rotary_embedding(
             rope_head_dim=cfg.d_model // cfg.n_heads,
-            pos_emb_config=pos_emb_config,
+            rope_impl=pos_emb_config['rope_impl'],
+            rope_theta=pos_emb_config['rope_theta'],
+            rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
+            rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
             max_seq_len=s).to(device)
         pos = torch.arange(s).unsqueeze(0).to(device=device)
         # adjust the position indices to account for padding tokens
diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py
index 55c6536871..9b2d471e19 100644
--- a/tests/test_rope_dail_vs_hf.py
+++ b/tests/test_rope_dail_vs_hf.py
@@ -4,67 +4,10 @@
 import pytest
 import torch
 from composer.core.precision import get_precision_context
+from omegaconf import OmegaConf as om
 
 from llmfoundry.models.layers.attention import is_flash_v2_installed
-
-if is_flash_v2_installed():
-    from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
-from omegaconf import OmegaConf as om
-from transformers.models.llama.modeling_llama import \
-    LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
-from transformers.models.llama.modeling_llama import \
-    LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
-from transformers.models.llama.modeling_llama import \
-    LlamaRotaryEmbedding as HFRotaryEmbedding
-
-
-def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict,
-                         max_seq_len: int):
-    if pos_emb_config['rope_impl'] == 'dail':
-        return DAILRotaryEmbedding(
-            dim=rope_head_dim,
-            base=pos_emb_config['rope_theta'],
-            interleaved=False,
-            scale_base=pos_emb_config['rope_dail_config']['xpos_scale_base'] if
-            (pos_emb_config['rope_dail_config']['type'] == 'xpos') else None,
-            pos_idx_in_fp32=pos_emb_config['rope_dail_config']
-            ['pos_idx_in_fp32'],
-            device=
-            'cpu',  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-        )
-    elif pos_emb_config['rope_impl'] == 'hf':
-        if pos_emb_config['rope_hf_config']['type'] == 'no_scaling':
-            return HFRotaryEmbedding(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=pos_emb_config['rope_theta'],
-                device=
-                'cpu'  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
-        elif pos_emb_config['rope_hf_config']['type'] == 'linear':
-            return HFLinearScalingRotaryEmbedding(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=pos_emb_config['rope_theta'],
-                scaling_factor=pos_emb_config['rope_hf_config']['factor'],
-                device=
-                'cpu'  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
-        elif pos_emb_config['rope_hf_config']['type'] == 'dynamic':
-            return HFDynamicNTKScalingRotaryEmbedding(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=pos_emb_config['rope_theta'],
-                scaling_factor=pos_emb_config['rope_hf_config']['factor'],
-                device=
-                'cpu'  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
-        else:
-            raise ValueError(
-                f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}'
-            )
-    else:
-        raise ValueError(f'Invalid rope_impl: {pos_emb_config["rope_impl"]}')
+from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
 
 
 @pytest.mark.gpu
@@ -128,10 +71,13 @@ def test_rope_dail_vs_hf(clip_qkv: bool,
         }
     }
 
-    dail_rope = gen_rotary_embedding(rope_head_dim=cfg.d_model //
-                                     cfg.n_heads,
-                                     pos_emb_config=dail_rope_config,
-                                     max_seq_len=seq_len).to('cuda')
+    dail_rope = gen_rotary_embedding(
+        rope_head_dim=cfg.d_model // cfg.n_heads,
+        rope_impl=dail_rope_config['rope_impl'],
+        rope_theta=dail_rope_config['rope_theta'],
+        rope_dail_config=dail_rope_config['rope_dail_config'],
+        rope_hf_config={},
+        max_seq_len=seq_len).to('cuda')
     dail_rope_w_meta_info = {
         'impl': 'dail',
         'rotary_emb': dail_rope,
@@ -139,9 +85,13 @@ def test_rope_dail_vs_hf(clip_qkv: bool,
         'seq_len': seq_len,
     }
 
-    hf_rope = gen_rotary_embedding(rope_head_dim=cfg.d_model // cfg.n_heads,
-                                   pos_emb_config=hf_rope_config,
-                                   max_seq_len=seq_len).to('cuda')
+    hf_rope = gen_rotary_embedding(
+        rope_head_dim=cfg.d_model // cfg.n_heads,
+        rope_impl=hf_rope_config['rope_impl'],
+        rope_theta=hf_rope_config['rope_theta'],
+        rope_dail_config={},
+        rope_hf_config=hf_rope_config['rope_hf_config'],
+        max_seq_len=seq_len).to('cuda')
     pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda')
     # adjust the position indices to account for padding tokens
     pos = torch.clamp(
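For reviewers, a minimal usage sketch of the now-shared `gen_rotary_embedding` helper with the flattened keyword signature introduced above. This is not part of the diff; the concrete values (head dimension 64, `rope_theta=10000`, `max_seq_len=2048`) are illustrative assumptions, and the config keys are taken from the configs used in the tests.

```python
# Illustrative only: mirrors how the updated tests call the shared helper.
# Numeric values below are assumptions, not values taken from the diff.
from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

hf_rotary = gen_rotary_embedding(
    rope_head_dim=64,  # typically d_model // n_heads from the model config
    rope_impl='hf',    # the 'dail' implementation additionally requires flash-attn v2
    rope_theta=10000,
    rope_dail_config={},  # unused when rope_impl == 'hf'
    rope_hf_config={'type': 'no_scaling', 'factor': 1.0},
    max_seq_len=2048,
)

# The module is constructed on CPU (see the FSDP comment in the diff) and can
# then be moved to the target device, as the tests do with .to('cuda').
```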