Commit

minor changes

ShashankMosaicML committed Nov 3, 2023
1 parent e59f784 commit c602a06
Showing 3 changed files with 26 additions and 72 deletions.
9 changes: 5 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py

@@ -86,9 +86,9 @@
log = logging.getLogger(__name__)


-def _rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
-                      rope_dail_config: dict, rope_hf_config: dict,
-                      max_seq_len: int):
+def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
+                         rope_dail_config: dict, rope_hf_config: dict,
+                         max_seq_len: int):
    if rope_impl == 'dail':
        return DAILRotaryEmbedding(
            dim=rope_head_dim,
@@ -127,6 +127,7 @@ def _rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
                device=
                'cpu'  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
            )
+    raise ValueError('rope_impl needs to be either dail or hf')


class MPTPreTrainedModel(PreTrainedModel):
@@ -186,7 +187,7 @@ def __init__(self, config: MPTConfig):
        self.rope_impl = None
        if self.rope:
            self.rope_impl = config.attn_config['rope_impl']
-            self.rotary_embedding = _rotary_embedding(
+            self.rotary_embedding = gen_rotary_embedding(
                rope_head_dim=config.d_model // config.n_heads,
                rope_impl=self.rope_impl,
                rope_theta=config.attn_config['rope_theta'],
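
For reference, a minimal sketch of how the renamed helper can be called directly, assuming only the keyword interface visible in this diff; the concrete values (head dimension, theta, sequence length, scaling config) are illustrative rather than repository defaults:

from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

# Hypothetical values; keyword names are taken from the signature above.
rotary_embedding = gen_rotary_embedding(
    rope_head_dim=128,        # d_model // n_heads
    rope_impl='hf',           # 'dail' or 'hf'; anything else now raises ValueError
    rope_theta=10000,
    rope_dail_config={},      # only read when rope_impl == 'dail'
    rope_hf_config={'type': 'no_scaling', 'factor': 1.0},
    max_seq_len=2048)
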
7 changes: 5 additions & 2 deletions tests/test_flash_triton_torch.py

@@ -6,7 +6,7 @@
from omegaconf import OmegaConf as om

from llmfoundry.models.layers.attention import is_flash_v2_installed
-from tests.test_rope_dail_vs_hf import gen_rotary_embedding
+from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding


def allclose_helper(t0: torch.Tensor,
@@ -130,7 +130,10 @@ def gen_bias(attn_impl: str):
    if rope:
        rotary_embedding = gen_rotary_embedding(
            rope_head_dim=cfg.d_model // cfg.n_heads,
-            pos_emb_config=pos_emb_config,
+            rope_impl=pos_emb_config['rope_impl'],
+            rope_theta=pos_emb_config['rope_theta'],
+            rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
+            rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
            max_seq_len=s).to(device)
        pos = torch.arange(s).unsqueeze(0).to(device=device)
        # adjust the position indices to account for padding tokens
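
The new keyword arguments unpack the test's pos_emb_config dict instead of passing it through whole; the .get(..., {}) defaults mean a config only needs the sub-dict for its own rope_impl. A hypothetical config of the shape this call expects, with key names taken from the diffs in this commit (values are examples only):

# Example only; key names mirror those read by the call above.
pos_emb_config = {
    'rope_impl': 'dail',
    'rope_theta': 10000,
    'rope_dail_config': {
        'type': 'xpos',           # 'xpos' switches on scale_base in the dail implementation
        'pos_idx_in_fp32': True,
        'xpos_scale_base': 512,
    },
}
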
82 changes: 16 additions & 66 deletions tests/test_rope_dail_vs_hf.py

@@ -4,67 +4,10 @@
import pytest
import torch
from composer.core.precision import get_precision_context
+from omegaconf import OmegaConf as om

from llmfoundry.models.layers.attention import is_flash_v2_installed
-
-if is_flash_v2_installed():
-    from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
-from omegaconf import OmegaConf as om
-from transformers.models.llama.modeling_llama import \
-    LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
-from transformers.models.llama.modeling_llama import \
-    LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
-from transformers.models.llama.modeling_llama import \
-    LlamaRotaryEmbedding as HFRotaryEmbedding
-
-
-def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict,
-                         max_seq_len: int):
-    if pos_emb_config['rope_impl'] == 'dail':
-        return DAILRotaryEmbedding(
-            dim=rope_head_dim,
-            base=pos_emb_config['rope_theta'],
-            interleaved=False,
-            scale_base=pos_emb_config['rope_dail_config']['xpos_scale_base'] if
-            (pos_emb_config['rope_dail_config']['type'] == 'xpos') else None,
-            pos_idx_in_fp32=pos_emb_config['rope_dail_config']
-            ['pos_idx_in_fp32'],
-            device=
-            'cpu',  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-        )
-    elif pos_emb_config['rope_impl'] == 'hf':
-        if pos_emb_config['rope_hf_config']['type'] == 'no_scaling':
-            return HFRotaryEmbedding(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=pos_emb_config['rope_theta'],
-                device=
-                'cpu'  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
-        elif pos_emb_config['rope_hf_config']['type'] == 'linear':
-            return HFLinearScalingRotaryEmbedding(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=pos_emb_config['rope_theta'],
-                scaling_factor=pos_emb_config['rope_hf_config']['factor'],
-                device=
-                'cpu'  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
-        elif pos_emb_config['rope_hf_config']['type'] == 'dynamic':
-            return HFDynamicNTKScalingRotaryEmbedding(
-                rope_head_dim,
-                max_position_embeddings=max_seq_len,
-                base=pos_emb_config['rope_theta'],
-                scaling_factor=pos_emb_config['rope_hf_config']['factor'],
-                device=
-                'cpu'  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
-            )
-        else:
-            raise ValueError(
-                f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}'
-            )
-    else:
-        raise ValueError(f'Invalid rope_impl: {pos_emb_config["rope_impl"]}')
+from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding


@pytest.mark.gpu
@@ -128,20 +71,27 @@ def test_rope_dail_vs_hf(clip_qkv: bool,
        }
    }

-    dail_rope = gen_rotary_embedding(rope_head_dim=cfg.d_model //
-                                     cfg.n_heads,
-                                     pos_emb_config=dail_rope_config,
-                                     max_seq_len=seq_len).to('cuda')
+    dail_rope = gen_rotary_embedding(
+        rope_head_dim=cfg.d_model // cfg.n_heads,
+        rope_impl=dail_rope_config['rope_impl'],
+        rope_theta=dail_rope_config['rope_theta'],
+        rope_dail_config=dail_rope_config['rope_dail_config'],
+        rope_hf_config={},
+        max_seq_len=seq_len).to('cuda')
    dail_rope_w_meta_info = {
        'imp': 'dail',
        'rotary_emb': dail_rope,
        'offset_info': 0,
        'seq_len': seq_len,
    }

-    hf_rope = gen_rotary_embedding(rope_head_dim=cfg.d_model // cfg.n_heads,
-                                   pos_emb_config=hf_rope_config,
-                                   max_seq_len=seq_len).to('cuda')
+    hf_rope = gen_rotary_embedding(
+        rope_head_dim=cfg.d_model // cfg.n_heads,
+        rope_impl=hf_rope_config['rope_impl'],
+        rope_theta=hf_rope_config['rope_theta'],
+        rope_dail_config={},
+        rope_hf_config=hf_rope_config['rope_hf_config'],
+        max_seq_len=seq_len).to('cuda')
    pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda')
    # adjust the position indices to account for padding tokens
    pos = torch.clamp(

0 comments on commit c602a06
