diff --git a/tests/test_model.py b/tests/test_model.py
index 5d8b069fc1..69aa05a362 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -675,9 +675,10 @@ def test_forward_with_padding(attention_impl: str, device: str,
             right_padding_output[0, :3],
             left_padding_output[0, 3:],
             atol=1e-6 if attention_impl == 'torch' else 1e-8)
-        if not (alibi or rope):
+        if not (alibi or (rope and pos_emb_config['rope_imp'] == 'dail')):
             # check that right padding and middle padding produce the same output
             # Note: alibi not implemented for middle padding.
+            # Note: dail implementation of rope does not support middle padding.
             assert torch.allclose(
                 right_padding_output[0, :3],
                 middle_padding_output[0, [0, 1, 5]],
@@ -695,9 +696,10 @@ def test_forward_with_padding(attention_impl: str, device: str,
             right_padding_output[0, :3],
             batched_output[0, :3],
             atol=1e-6 if attention_impl == 'torch' else 1e-8)
-        if not (alibi or rope):
+        if not (alibi or (rope and pos_emb_config['rope_imp'] == 'dail')):
             # check that middle padding and middle padding in a batch produce the same output
             # Note: alibi not implemented for middle padding.
+            # Note: dail implementation of rope does not support middle padding.
             assert torch.allclose(
                 middle_padding_output[0],
                 batched_output[1, :],