Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed May 8, 2024
1 parent 5ee79e2 commit 54f14e9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
run: |
pip install --upgrade pip
pip install -r requirements.txt
- name: Run tests and check at least 85% coverage
- name: Run tests and check at least 90% coverage
run: |
pytest
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[pytest]
addopts = --cov=msccl --cov-report term-missing:skip-covered --cov-fail-under 85 -n auto
addopts = --cov=msccl --cov-report term-missing:skip-covered --cov-fail-under 90 -n auto
46 changes: 46 additions & 0 deletions tests/test_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,49 @@ def test_routines_allreduce_packet_inplace_mscclpp():
c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)
Json()
assert Check()

def test_routines_allreduce_inplace_mscclpp():
size = 8
topology = fully_connected(size)
collective = AllReduce(size, size * size, True)
with MSCCLPPProgram("allreduce_pairs", topology, collective, 2, protocol="Simple"):
# Each rank sends the nth chunk to the nth rank into scratch space
for rank in range(size):
for tb in range(size):
index = rank * size
c = chunk(rank, Buffer.input, index + tb)
# make sure the data is ready
for nghr in range(size):
peer_index = nghr * size
if rank != nghr:
c_peer = chunk(rank, Buffer.input, peer_index + tb)
c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb)
for nghr in range(size):
if rank != nghr:
c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
# reduce the chunks
for i in range(size):
nghr = (rank + i) % size
if rank != nghr:
c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb)
for nghr in range(size):
if rank != nghr:
c.signal(nghr, Buffer.input, index + tb, sendtb=tb)

# wait for all the chunks is ready, then get the chunks
for rank in range(size):
for tb in range(size):
for nghr in range(size):
if rank != nghr:
index = nghr * size
c = chunk(rank, Buffer.input, index + tb)
c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
for i in range(size):
nghr = (rank + i) % size
index = nghr * size
if rank != nghr:
c = chunk(rank, Buffer.input, index + tb)
c.get(nghr, Buffer.input, index + tb, recvtb=tb)

Json()
assert Check()

0 comments on commit 54f14e9

Please sign in to comment.