diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 4ba4be44..0d1e4632 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -1,5 +1,4 @@ import os -from contextlib import nullcontext as does_not_raise from typing import Any import pytest @@ -147,25 +146,13 @@ def _test_column_linear( @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) -@pytest.mark.parametrize( - "tp_mode,async_communication,expectation", - [ - pytest.param(TensorParallelLinearMode.ALL_REDUCE, False, does_not_raise()), - pytest.param(TensorParallelLinearMode.REDUCE_SCATTER, False, does_not_raise()), - pytest.param(TensorParallelLinearMode.REDUCE_SCATTER, True, does_not_raise()), - pytest.param( - TensorParallelLinearMode.ALL_REDUCE, - True, - pytest.raises( - ValueError, - match=r"Cf this: https://github.com/huggingface/nanotron/blob/bf82cded9eef1ba77864b48e65bffefad4076339/src/nanotron/core/parallel/tensor_parallel/nn.py#L132", - ), - ), - ], -) +@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) +@pytest.mark.parametrize("async_communication", [False, True]) def test_row_linear( tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool, expectation: Any ): + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: + pytest.skip("ALL_REDUCE mode does not support async communication") init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( tp_mode=tp_mode, async_communication=async_communication, expectation=expectation )