From 610a499465e632c205528d0d4624107fa7897f7a Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 13 Jun 2024 14:02:26 +0800 Subject: [PATCH] Bug fix: avoid op merge if circular dependencies would be introduced (#6) - generate JSON for the reduce_packet op - check whether circular dependencies would be introduced when doing op fusion --- msccl/language/collectives.py | 6 +++-- msccl/language/instruction_dag.py | 20 ++++++++++++++++ msccl/language/mscclpp/instruction_dag.py | 16 +++++++++++++ msccl/language/mscclpp/ir.py | 2 +- tests/test_language.py | 29 +++++++++++++++++++++++ 5 files changed, 70 insertions(+), 3 deletions(-) diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index a6fbcfa..9e01982 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -122,8 +122,10 @@ def get_buffer_index(self, rank, buffer, index): class AllReduce(Collective): - def __init__(self, num_ranks, chunk_factor, inplace): - Collective.__init__(self, num_ranks, chunk_factor, inplace, num_ranks) + def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups=None): + if num_chunk_groups == None: + num_chunk_groups = num_ranks + Collective.__init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups) self.name = "allreduce" def init_buffers(self): diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py index d099126..451ea7e 100755 --- a/msccl/language/instruction_dag.py +++ b/msccl/language/instruction_dag.py @@ -12,11 +12,15 @@ def remove_op(op: Op): for p in op.prev: p.next.remove(op) p.next += op.next + p.next = list(set(p.next)) for n in op.next: n.prev.remove(op) n.prev = op.prev.union(n.prev) + op.next = [] + op.prev = [] + def merge_op(op: Op, other_op: Op): if other_op in op.next: @@ -34,6 +38,22 @@ def merge_op(op: Op, other_op: Op): op.next = list(set(op.next + other_op.next)) +def circular_dep_after_merge(op: Op, other_op: Op): + root = set([op, other_op]) + frontier = set(op.next) + if other_op in 
frontier: + frontier.remove(other_op) + frontier = list(frontier.union(other_op.next)) + while len(frontier) > 0: + current = frontier[0] + for n in current.next: + # The root node will be visited again if there is a circular dependency + if n in root: + return True + frontier.append(n) + frontier = frontier[1:] + + def same_tb(op1: Op, op2: Op): return op1.tb == op2.tb and op1.channel == op2.channel diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index 32d15b1..65a6455 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -8,6 +8,7 @@ buf_dst_src_match, merge_op, remove_op, + circular_dep_after_merge, same_buf_dst, same_buf_src, same_chan_type, @@ -234,6 +235,7 @@ def _optimize_rrc_r_signal_wait(self): and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): op.srcs.append( ( @@ -257,6 +259,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.inst == Instruction.reduce and same_buf_dst(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): op.srcs.append( ( @@ -280,6 +283,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.inst == Instruction.reduce_packet and same_buf_dst(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): op.srcs.append( ( @@ -303,6 +307,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.inst == Instruction.signal and same_buf_src(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): op.dsts.append( ( @@ -334,6 +339,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.inst == Instruction.wait and same_buf_dst(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): op.srcs.append( ( @@ -376,6 +382,7 @@ def _optimize_rrcs_rs(self): and same_count(op, next_op) and buf_dst_src_match(op, next_op) and 
same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue @@ -404,6 +411,7 @@ def _optimize_rrcs_rs(self): and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm + and not circular_dep_after_merge(op, next_op) ): if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue @@ -433,6 +441,7 @@ def _optimize_rrcs_rs(self): and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm + and not circular_dep_after_merge(op, next_op) ): if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue @@ -473,6 +482,7 @@ def _optimize_get_put(self): and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( @@ -501,6 +511,7 @@ def _optimize_get_put(self): and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( @@ -529,6 +540,7 @@ def _optimize_get_put(self): and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( @@ -546,6 +558,8 @@ def _optimize_get_put(self): tb.ops.remove(seq_op) queue.remove(seq_op) fused = True + if fused: + continue queue = queue[1:] # For signal/wait ops, if they are independent of other operations and no other operations in between, @@ -567,6 +581,7 @@ def _parallel_signal_wait(self): seq_op.inst == Instruction.signal and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( @@ -594,6 +609,7 @@ def _parallel_signal_wait(self): seq_op.inst == Instruction.wait and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, 
seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index 0442e88..d7d068a 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -236,7 +236,7 @@ def remove_empty_fields(d): srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) dst = op.dst src = op.dst # TODO(binyli): fix this - elif op.inst == Instruction.reduce: + elif op.inst == Instruction.reduce or op.inst == Instruction.reduce_packet: srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) dst = op.dst elif op.inst == Instruction.nop: diff --git a/tests/test_language.py b/tests/test_language.py index fcc0b71..e61c6af 100755 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -236,6 +236,35 @@ def test_instruction_fusion_mscclpp(): assert lowered_prgm.gpus[2].threadblocks[0].ops[2].inst == MscclppInstruction.signal +def test_instruction_fusion_multi_deps_mscclpp(): + topology = fully_connected(3) + collective = AllReduce(3, 1, True) + prgm = MSCCLPPProgram("allreduce", topology, collective, 1) + # The dependency graph for rank 1 is as follows: + # put(0i to 1s) => reduce(1s to 1i) => put(2i to 1s) => reduce(1s to 1i) + # | => put(1i to 0s) ^ + # | => put(1i to 2s)------------------- -| + # put(2i to 1s) => reduce(1s to 1i) for read after write + # put(1i to 2s) => reduce(1s to 1i) for write after read + # when we try to merge reduce(1s to 1i) => put(2i to 1s) => reduce(1s to 1i), + # circular dependency is introduced + with prgm: + c0 = chunk(0, Buffer.input, 0) + c0.put_packet(1, "scratch", 0, sendtb=0) + c1s = chunk(1, "scratch", 0) + c1 = chunk(1, Buffer.input, 0) + c1 = c1.reduce_packet(c1s, recvtb=0) + c1.put_packet(0, "scratch", 0, sendtb=0) + c1.put_packet(2, "scratch", 0, sendtb=0) + c2 = chunk(2, Buffer.input, 0) + c2.put_packet(1, "scratch", 0, sendtb=0) + c1.reduce_packet(c1s, recvtb=0) + lowered_prgm = prgm.lower() + 
lowered_prgm.gpus[1].threadblocks = [tb for tb in lowered_prgm.gpus[1].threadblocks if tb.id != -1] + assert lowered_prgm.gpus[1].threadblocks[0].ops[0].inst == MscclppInstruction.reduce_send_packet + assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == MscclppInstruction.reduce_packet + + def test_replication(): topology = fully_connected(2) collective = AllToAll(2, 1, False)