wip
dakinggg committed Nov 18, 2023
1 parent f23a3f6 commit c125b3b
Showing 3 changed files with 34 additions and 40 deletions.
18 changes: 12 additions & 6 deletions tests/test_flash_triton_torch.py
@@ -5,6 +5,7 @@
import torch
from omegaconf import OmegaConf as om

from llmfoundry.models.layers import attention
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

@@ -17,8 +18,14 @@ def allclose_helper(t0: torch.Tensor,


@pytest.mark.gpu
@pytest.mark.parametrize('attn_impl_0', ['flash', 'triton', 'torch'])
@pytest.mark.parametrize('attn_impl_1', ['flash', 'triton', 'torch'])
@pytest.mark.parametrize('attn_impl_0,attn_impl_1', [
('flash', 'flash'),
('flash', 'triton'),
('flash', 'torch'),
('triton', 'triton'),
('triton', 'torch'),
('torch', 'torch'),
])
@pytest.mark.parametrize('clip_qkv', [True, False])
@pytest.mark.parametrize('qk_ln', [True, False])
@pytest.mark.parametrize('pos_emb_config', [{
@@ -62,11 +69,10 @@ def test_attn_impl(attn_impl_0: str,
Includes testing with and without attn_clip_qkv, attn_qk_ln, alibi, and
rope.
"""
from llmfoundry.models.layers import attention
alibi = pos_emb_config['alibi']
rope = pos_emb_config['rope']
if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'):
pytest.xfail('flash attn does not support alibi')
pytest.skip('flash attn does not support alibi')

if rope and (pos_emb_config['rope_impl']
== 'dail') and (not is_flash_v2_installed()):
@@ -81,7 +87,7 @@ def test_attn_impl(attn_impl_0: str,
'qk_ln': qk_ln,
})

n, s, f = 2, 16, cfg.d_model
n, s, f = 2, 4, cfg.d_model
assert cfg.d_model % cfg.n_heads == 0
if attn_type == 'grouped_query_attention':
cfg.kv_n_heads = 2
@@ -311,7 +317,7 @@ def test_grouped_attention_heads(attn_impl: str,
'kv_n_heads': kv_n_heads
})

n, s, f = 2, 16, cfg.d_model
n, s, f = 2, 4, cfg.d_model

mmhsa = attention.GroupedQueryAttention(**cfg).to(device)

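For reference, the reworked parametrization in tests/test_flash_triton_torch.py above lists each unordered pair of attention implementations exactly once, so the 3 x 3 cross product of the two earlier decorators shrinks to 6 cases. A minimal standalone sketch of the same pairing pattern, with an illustrative test name and body that are not part of the commit:

import itertools

import pytest

IMPLS = ['flash', 'triton', 'torch']

# Each unordered pair appears exactly once: ('flash', 'flash'), ('flash', 'triton'),
# ('flash', 'torch'), ('triton', 'triton'), ('triton', 'torch'), ('torch', 'torch').
PAIRS = list(itertools.combinations_with_replacement(IMPLS, 2))


@pytest.mark.parametrize('attn_impl_0,attn_impl_1', PAIRS)
def test_impl_pairs_sketch(attn_impl_0: str, attn_impl_1: str):
    # Stand-in assertion; the real test above compares the two implementations.
    assert attn_impl_0 in IMPLS and attn_impl_1 in IMPLS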
31 changes: 16 additions & 15 deletions tests/test_hf_conversion_script.py
@@ -263,7 +263,7 @@ def test_callback_inits():
@pytest.mark.parametrize('log_to_mlflow', [True, False])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
[('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)])
@patch('os.cpu_count', MagicMock(return_value=None))
def test_huggingface_conversion_callback_interval(
tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str,
@@ -273,12 +273,12 @@ def test_huggingface_conversion_callback_interval(

dist.initialize_dist(get_device('gpu'))

max_seq_len = 16
device_batch_size = 1
dataset_size = 14
max_seq_len = 4
device_batch_size = 2
dataset_size = 8
precision_str = 'bfloat16'
precision = torch.bfloat16
batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))
batches_per_epoch = math.ceil(dataset_size / device_batch_size)

checkpointer_callback = HuggingFaceCheckpointer(
save_folder=os.path.join(tmp_path, 'checkpoints'),
@@ -292,7 +292,7 @@ model_cfg = {
model_cfg = {
'name': 'mpt_causal_lm',
'init_device': 'cpu',
'd_model': 128,
'd_model': 64,
'n_heads': 2,
'n_layers': 2,
'expansion_ratio': 4,
@@ -401,7 +401,7 @@ def test_huggingface_conversion_callback_interval(
]
assert len(normal_checkpoints) == expected_normal_checkpoints
assert len(huggingface_checkpoints) == expected_hf_checkpoints

print(huggingface_checkpoints)
# Load the last huggingface checkpoint
loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
os.path.join(tmp_path, 'checkpoints', 'huggingface',
@@ -428,7 +428,7 @@ def test_huggingface_conversion_callback_interval(
trust_remote_code=True,
)

check_hf_model_equivalence(trainer.state.model.module.model.to(precision),
check_hf_model_equivalence(trainer.state.model.model.to(precision),
loaded_model)
check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)

@@ -442,14 +442,15 @@ def test_huggingface_conversion_callback_interval(
[('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
)
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '7ba', 3, 4)])
@patch('os.cpu_count', MagicMock(return_value=None))
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
tie_word_embeddings: bool,
fsdp_state_dict_type: Optional[str],
hf_save_interval: str,
save_interval: str, max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int):
def test_huggingface_conversion_callback(
model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool,
fsdp_state_dict_type: Optional[str],
hf_save_interval: str, save_interval: str, max_duration: str,
expected_hf_checkpoints: int, expected_normal_checkpoints: int):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))
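For reference, the smaller interval expectations above are consistent with the reduced run length: dataset_size = 8 with device_batch_size = 2 gives 4 batches per epoch, and with a checkpoint also written on the final batch, a '2ba' save interval over '4ba' yields 2 regular checkpoints while a '3ba' interval yields 2 HuggingFace checkpoints. A small arithmetic sketch of that reasoning; the end-of-training save is an assumption inferred from the expected counts, and the helper name is illustrative:

import math

dataset_size, device_batch_size = 8, 2
batches_per_epoch = math.ceil(dataset_size / device_batch_size)  # 4


def expected_checkpoints(interval_batches: int, total_batches: int) -> int:
    # One save every `interval_batches`, plus one at the end of training
    # if a save is not already due on the final batch (assumed behavior).
    saves = set(range(interval_batches, total_batches + 1, interval_batches))
    saves.add(total_batches)
    return len(saves)


assert expected_checkpoints(2, 4) == 2  # expected_normal_checkpoints for '2ba' over '4ba'
assert expected_checkpoints(3, 4) == 2  # expected_hf_checkpoints for '3ba' over '4ba'
assert expected_checkpoints(2, 7) == 4  # matches the previous '2ba' over '7ba' expectation
assert expected_checkpoints(3, 7) == 3  # matches the previous '3ba' over '7ba' expectation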
25 changes: 6 additions & 19 deletions tests/test_model.py
@@ -874,11 +874,11 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
save_path = tmp_path / 'test-device-map'
hf_config = MPTConfig(
init_device='cpu',
d_model=128,
d_model=64,
n_heads=4,
n_layers=2,
expansion_ratio=2,
max_seq_len=2048,
max_seq_len=4,
emb_pdrop=0.1,
resid_pdrop=0.2,
attn_config={
@@ -914,8 +914,8 @@
)
with torch.autocast('cuda', dtype=torch.bfloat16):
_ = pipe(
'The quick fox jumped over',
max_length=10,
'The fox',
max_new_tokens=2,
do_sample=True,
)

@@ -1482,18 +1482,17 @@ def test_model_to(attention_impl: str, pos_emb_config: dict,

hf_config = MPTConfig(
init_device='cpu',
d_model=128,
d_model=64,
n_heads=4,
n_layers=2,
expansion_ratio=2,
max_seq_len=2048,
max_seq_len=4,
emb_pdrop=0.1,
resid_pdrop=0.2,
attn_config={
'attn_impl': attention_impl,
**pos_emb_config,
},
use_cache=True,
init_config={
'name': 'baseline_',
'init_std': 0.02,
@@ -1509,11 +1508,9 @@
input_ids = torch.tensor([[11274, 16390, 11]]).to('cuda')
attention_mask = torch.tensor([[1, 1, 1]]).bool().to('cuda')

# with get_precision_context('amp_bf16'):
_ = mpt(input_ids, attention_mask=attention_mask)

# move the model around using different methods
mpt = mpt.bfloat16()
mpt = mpt.to('cpu')

# verify the model still works
@@ -1523,23 +1520,13 @@
_ = mpt(input_ids.to('cpu'),
attention_mask=attention_mask.to('cpu'))

mpt = mpt.cuda()
mpt = mpt.bfloat16()

# verify the model still works
if attention_impl == 'torch':
with torch.autocast('cuda', dtype=torch.bfloat16, enabled=True):
_ = mpt(input_ids, attention_mask=attention_mask)

mpt = mpt.to('cpu')
mpt = mpt.float()

# verify the model still works
if attention_impl == 'torch' and not (
pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'):
_ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu'))

mpt = mpt.half()
mpt = mpt.to(0) # move to rank0
mpt = mpt.bfloat16()

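For reference, the generation tweak in test_generate_with_device_map above replaces max_length, which caps prompt plus generated tokens, with max_new_tokens, which caps only the generated continuation, keeping the shorter prompt plus its continuation within the reduced max_seq_len of 4. A minimal sketch of the same transformers pipeline call; the model name is a stand-in, not the MPT checkpoint the test loads:

import transformers

# Stand-in model; the test above loads the MPT checkpoint it just saved.
pipe = transformers.pipeline('text-generation', model='gpt2')

# max_new_tokens bounds only the newly generated tokens, independent of prompt length.
out = pipe('The fox', max_new_tokens=2, do_sample=True)
print(out[0]['generated_text'])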
