From 54f14e9f1c638ab53734c123f33bfaa7de0eac8e Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 8 May 2024 01:58:28 +0000 Subject: [PATCH] add more tests --- .github/workflows/tests.yaml | 2 +- pytest.ini | 2 +- tests/test_language.py | 46 ++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a0c1bdd..133e11a 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -31,7 +31,7 @@ jobs: run: | pip install --upgrade pip pip install -r requirements.txt - - name: Run tests and check at least 85% coverage + - name: Run tests and check at least 90% coverage run: | pytest diff --git a/pytest.ini b/pytest.ini index 4621e92..d68bf05 100755 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,2 @@ [pytest] -addopts = --cov=msccl --cov-report term-missing:skip-covered --cov-fail-under 85 -n auto +addopts = --cov=msccl --cov-report term-missing:skip-covered --cov-fail-under 90 -n auto diff --git a/tests/test_language.py b/tests/test_language.py index 1b95cd2..fcc0b71 100755 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -363,3 +363,49 @@ def test_routines_allreduce_packet_inplace_mscclpp(): c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) Json() assert Check() + +def test_routines_allreduce_inplace_mscclpp(): + size = 8 + topology = fully_connected(size) + collective = AllReduce(size, size * size, True) + with MSCCLPPProgram("allreduce_pairs", topology, collective, 2, protocol="Simple"): + # Each rank sends the nth chunk to the nth rank into scratch space + for rank in range(size): + for tb in range(size): + index = rank * size + c = chunk(rank, Buffer.input, index + tb) + # make sure the data is ready + for nghr in range(size): + peer_index = nghr * size + if rank != nghr: + c_peer = chunk(rank, Buffer.input, peer_index + tb) + c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) + for nghr in range(size): + if rank != nghr: + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + # reduce the chunks + for i in range(size): + nghr = (rank + i) % size + if rank != nghr: + c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) + for nghr in range(size): + if rank != nghr: + c.signal(nghr, Buffer.input, index + tb, sendtb=tb) + + # wait for all the chunks is ready, then get the chunks + for rank in range(size): + for tb in range(size): + for nghr in range(size): + if rank != nghr: + index = nghr * size + c = chunk(rank, Buffer.input, index + tb) + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + for i in range(size): + nghr = (rank + i) % size + index = nghr * size + if rank != nghr: + c = chunk(rank, Buffer.input, index + tb) + c.get(nghr, Buffer.input, index + tb, recvtb=tb) + + Json() + assert Check()