
Commit f02034b

ShashankMosaicML committed Nov 18, 2023
1 parent e96b234 · commit f02034b
Showing 1 changed file with 8 additions and 6 deletions.
tests/test_model.py: 14 changes (8 additions & 6 deletions)
@@ -519,11 +519,12 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     assert block.resid_ffn_dropout.p == 0.2
 
 
-@pytest.mark.gpu
-@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'),
-                                                   ('flash', 'gpu'),
-                                                   ('triton', 'gpu'),
-                                                   ('torch', 'gpu')])
+@pytest.mark.parametrize('attention_impl', [
+    'torch',
+    pytest.param('flash', marks=pytest.mark.gpu),
+    pytest.param('triton', marks=pytest.mark.gpu),
+    pytest.param('torch', marks=pytest.mark.gpu),
+])
 @pytest.mark.parametrize('pos_emb_config', [{
     'alibi': True,
     'rope': False
@@ -548,10 +549,11 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     },
 }])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_sequence_id_based_masking(attention_impl: str, device: str,
+def test_sequence_id_based_masking(attention_impl: str,
                                    pos_emb_config: dict,
                                    tie_word_embeddings: bool):
     # Testing the output of concatenated sequence with sequence id masking vs individual sequences.
+    device = 'gpu' if torch.cuda.is_available() else 'cpu'
     if not torch.cuda.is_available() and device == 'gpu':
         pytest.skip(
             f'This test requires CUDA to be available in order to run with {attention_impl} attention.'
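
A note on the pattern this commit applies: instead of gating the whole test behind @pytest.mark.gpu, each GPU-only parameter combination carries its own mark via pytest.param(..., marks=...), so the plain 'torch' case keeps running in CPU-only CI; the device is then derived at runtime rather than parametrized. Below is a minimal, self-contained sketch of that pattern; the test name, test body, and the 'gpu' marker registration are illustrative stand-ins, not code from this commit:

# A minimal sketch of per-parameter GPU marks in pytest. Assumes a 'gpu'
# marker is registered, e.g. in pytest.ini: markers = gpu: requires a GPU.
import pytest
import torch


@pytest.mark.parametrize('attn_impl', [
    'torch',  # unmarked: collected on every runner
    pytest.param('flash', marks=pytest.mark.gpu),  # carries the gpu mark
])
def test_attn_impl_smoke(attn_impl: str):
    # Derive the device at runtime, as the commit does, instead of
    # parametrizing it; then skip combinations this host cannot run.
    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    if attn_impl == 'flash' and device == 'cpu':
        pytest.skip('flash attention requires CUDA')
    assert attn_impl in ('torch', 'flash')

The mark itself does not skip anything; it enables selection, so CPU-only runners would typically deselect the marked cases with pytest -m "not gpu" while GPU runners select them with pytest -m gpu.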
