From 0798c323a46c7724619152bf9e7b1379310e12ff Mon Sep 17 00:00:00 2001
From: Peter Andreas Entschev
Date: Mon, 13 Nov 2023 09:35:10 -0800
Subject: [PATCH] Add `test_send_recv_protocol_ucxx`

Register a `ucxx` pytest marker, add `ucxx_cluster` and `ucxx_client`
fixtures, and add `test_send_recv_protocol_ucxx` next to the existing
`test_send_recv_protocol_ucx`.

The `ucxx_cluster` fixture skips when `distributed_ucxx` cannot be
imported. Otherwise it yields the value of the `SCHEDULER_FILE`
environment variable when that is set, or a `LocalCUDACluster` created
with `protocol="ucxx"` that is closed on teardown.
---
 python/raft-dask/pytest.ini                   |  1 +
 python/raft-dask/raft_dask/test/conftest.py   | 22 ++++++++++++++++++
 python/raft-dask/raft_dask/test/test_comms.py | 23 +++++++++++++++++++
 3 files changed, 46 insertions(+)

diff --git a/python/raft-dask/pytest.ini b/python/raft-dask/pytest.ini
index f810351201..c70b37b39d 100644
--- a/python/raft-dask/pytest.ini
+++ b/python/raft-dask/pytest.ini
@@ -7,3 +7,4 @@ markers =
     memleak: marks a test as a memory leak test
     nccl: marks a test as using NCCL
     ucx: marks a test as using UCX
+    ucxx: marks a test as using UCXX
diff --git a/python/raft-dask/raft_dask/test/conftest.py b/python/raft-dask/raft_dask/test/conftest.py
index d1baa684d4..dac50b8884 100644
--- a/python/raft-dask/raft_dask/test/conftest.py
+++ b/python/raft-dask/raft_dask/test/conftest.py
@@ -34,6 +34,21 @@ def ucx_cluster():
         cluster.close()
 
 
+@pytest.fixture(scope="session")
+def ucxx_cluster():
+    pytest.importorskip("distributed_ucxx")
+
+    scheduler_file = os.environ.get("SCHEDULER_FILE")
+    if scheduler_file:
+        yield scheduler_file
+    else:
+        cluster = LocalCUDACluster(
+            protocol="ucxx",
+        )
+        yield cluster
+        cluster.close()
+
+
 @pytest.fixture(scope="session")
 def client(cluster):
     client = create_client(cluster)
@@ -48,6 +63,13 @@ def ucx_client(ucx_cluster):
     client.close()
 
 
+@pytest.fixture()
+def ucxx_client(ucxx_cluster):
+    client = create_client(ucxx_cluster)
+    yield client
+    client.close()
+
+
 def create_client(cluster):
     """
     Create a Dask distributed client for a specified cluster.
diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py
index e18f70a718..89b0424167 100644
--- a/python/raft-dask/raft_dask/test/test_comms.py
+++ b/python/raft-dask/raft_dask/test/test_comms.py
@@ -310,6 +310,29 @@ def test_send_recv_protocol_ucx(n_trials, ucx_client):
     assert list(map(lambda x: x.result(), dfs))
 
 
+@pytest.mark.ucxx
+@pytest.mark.parametrize("n_trials", [1, 5])
+def test_send_recv_protocol_ucxx(n_trials, ucxx_client):
+
+    cb = Comms(comms_p2p=True, verbose=True)
+    cb.init()
+
+    dfs = [
+        ucxx_client.submit(
+            func_test_send_recv,
+            cb.sessionId,
+            n_trials,
+            pure=False,
+            workers=[w],
+        )
+        for w in cb.worker_addresses
+    ]
+
+    wait(dfs, timeout=5)
+
+    assert list(map(lambda x: x.result(), dfs))
+
+
 @pytest.mark.nccl
 @pytest.mark.parametrize("n_trials", [1, 5])
 def test_device_send_or_recv(n_trials, client):