From 40d5471b34acb2dd81ee6aaf7c8accb08b1c3f4c Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 13 Nov 2023 09:33:34 -0800 Subject: [PATCH] Test UCX comms with Dask TCP and UCX protocols --- python/raft-dask/raft_dask/test/test_comms.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 68c9fee556..e18f70a718 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -266,7 +266,7 @@ def test_comm_split(client): @pytest.mark.ucx @pytest.mark.parametrize("n_trials", [1, 5]) -def test_send_recv(n_trials, client): +def test_send_recv_protocol_tcp(n_trials, client): cb = Comms(comms_p2p=True, verbose=True) cb.init() @@ -287,6 +287,29 @@ def test_send_recv(n_trials, client): assert list(map(lambda x: x.result(), dfs)) +@pytest.mark.ucx +@pytest.mark.parametrize("n_trials", [1, 5]) +def test_send_recv_protocol_ucx(n_trials, ucx_client): + + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + ucx_client.submit( + func_test_send_recv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs)) + + @pytest.mark.nccl @pytest.mark.parametrize("n_trials", [1, 5]) def test_device_send_or_recv(n_trials, client):