diff --git a/tests/test_language.py b/tests/test_language.py index 4124b58..e61c6af 100755 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -240,9 +240,14 @@ def test_instruction_fusion_multi_deps_mscclpp(): topology = fully_connected(3) collective = AllReduce(3, 1, True) prgm = MSCCLPPProgram("allreduce", topology, collective, 1) - # last reduce_packet depends on put_packets(write after read) - # and first reduce_packet(rank 2 put data to rank 1, data put depends on first reduce_packet). - # In this case, we don't need to fuse the operations. + # The dependency graph for rank 1 is as follows: + # put(0i to 1s) => reduce(1s to 1i) => put(2i to 1s) => reduce(1s to 1i) + # | => put(1i to 0s) ^ + # | => put(1i to 2s)------------------- -| + # put(2i to 1s) => reduce(1s to 1i) is a read-after-write dependency + # put(1i to 2s) => reduce(1s to 1i) is a write-after-read dependency + # When we try to merge reduce(1s to 1i) => put(2i to 1s) => reduce(1s to 1i), + # a circular dependency is introduced. with prgm: c0 = chunk(0, Buffer.input, 0) c0.put_packet(1, "scratch", 0, sendtb=0)