Skip to content

Commit

Permalink
Add test_send_recv_protocol_ucxx
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Nov 13, 2023
1 parent 40d5471 commit 0798c32
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/raft-dask/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ markers =
memleak: marks a test as a memory leak test
nccl: marks a test as using NCCL
ucx: marks a test as using UCX
ucxx: marks a test as using UCXX
22 changes: 22 additions & 0 deletions python/raft-dask/raft_dask/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ def ucx_cluster():
cluster.close()


@pytest.fixture(scope="session")
def ucxx_cluster():
    """
    Session-scoped fixture providing a cluster that uses the UCXX protocol.

    The requesting test is skipped when ``distributed_ucxx`` cannot be
    imported. If the ``SCHEDULER_FILE`` environment variable is set to a
    non-empty value, that value is yielded as-is and nothing is started
    or torn down here. Otherwise a ``LocalCUDACluster`` with
    ``protocol="ucxx"`` is created, yielded, and closed afterwards.
    """
    pytest.importorskip("distributed_ucxx")

    scheduler_path = os.environ.get("SCHEDULER_FILE")
    if scheduler_path:
        # An externally provided scheduler: hand the value over untouched.
        yield scheduler_path
        return

    local_cluster = LocalCUDACluster(protocol="ucxx")
    yield local_cluster
    # Teardown: only clusters created by this fixture are closed.
    local_cluster.close()


@pytest.fixture(scope="session")
def client(cluster):
client = create_client(cluster)
Expand All @@ -48,6 +63,13 @@ def ucx_client(ucx_cluster):
client.close()


@pytest.fixture()
def ucxx_client(ucxx_cluster):
    """
    Function-scoped fixture yielding a client built for ``ucxx_cluster``.

    The client is obtained from ``create_client`` and is closed once the
    requesting test has finished.
    """
    dask_client = create_client(ucxx_cluster)
    yield dask_client
    # Teardown: release the client created for this test.
    dask_client.close()


def create_client(cluster):
"""
Create a Dask distributed client for a specified cluster.
Expand Down
23 changes: 23 additions & 0 deletions python/raft-dask/raft_dask/test/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,29 @@ def test_send_recv_protocol_ucx(n_trials, ucx_client):
assert list(map(lambda x: x.result(), dfs))


@pytest.mark.ucxx
@pytest.mark.parametrize("n_trials", [1, 5])
def test_send_recv_protocol_ucxx(n_trials, ucxx_client):
    """
    Run ``func_test_send_recv`` on every worker of a UCXX cluster.

    One task is submitted per address in ``cb.worker_addresses``, pinned
    to that worker. The test passes only if at least one task ran and
    every task returned a truthy result.

    Parameters
    ----------
    n_trials : int
        Forwarded unchanged to ``func_test_send_recv``.
    ucxx_client : fixture
        Client used to submit the per-worker tasks.
    """
    # NOTE(review): Comms is not handed ``ucxx_client`` explicitly and is
    # never torn down here; presumably it picks up the active client and
    # cleanup is left to the fixtures -- confirm against the Comms API.
    cb = Comms(comms_p2p=True, verbose=True)
    cb.init()

    futures = [
        ucxx_client.submit(
            func_test_send_recv,
            cb.sessionId,
            n_trials,
            # pure=False so that identical submissions are not deduplicated.
            pure=False,
            workers=[w],
        )
        for w in cb.worker_addresses
    ]

    wait(futures, timeout=5)

    results = [future.result() for future in futures]

    # A bare ``assert list(...)`` only verifies that the list is non-empty,
    # so a worker reporting failure would go unnoticed. Keep the non-empty
    # check and additionally require every individual result to be truthy.
    assert results, "no send/recv tasks were submitted"
    assert all(results), f"send/recv failed on some workers: {results}"


@pytest.mark.nccl
@pytest.mark.parametrize("n_trials", [1, 5])
def test_device_send_or_recv(n_trials, client):
Expand Down

0 comments on commit 0798c32

Please sign in to comment.