def circular_dep_after_merge(op: Op, other_op: Op) -> bool:
    """Report whether merging ``other_op`` into ``op`` would create a cycle.

    The two ops are treated as one merged node. The walk starts from the
    successors of ``op`` (with ``other_op`` itself dropped, since that edge
    collapses in the merge) together with the successors of ``other_op``,
    and follows ``next`` links. Finding ``op`` or ``other_op`` among the
    successors of any visited node means the merged node would end up
    depending on itself.

    Args:
        op: The op that would absorb ``other_op``.
        other_op: The op that would be merged into ``op``.

    Returns:
        True if a successor path leads back to ``op`` or ``other_op``,
        False otherwise.
    """
    root = {op, other_op}

    # Build the start set in the same order as before: drop other_op from
    # op's successors first, then add other_op's own successors.
    start = set(op.next)
    start.discard(other_op)
    start.update(other_op.next)

    # Track every node already queued so each one is expanded at most once.
    # Without this, nodes reachable via several paths are re-expanded over
    # and over, and a downstream cycle that does not touch the root would
    # keep the walk going forever.
    seen = set(start)
    pending = list(start)
    while pending:
        # Visit order does not affect the boolean answer, so pop from the
        # end (O(1)) instead of re-slicing the list on every step.
        current = pending.pop()
        for successor in current.next:
            # Reaching a root node again means there is a circular dependency.
            if successor in root:
                return True
            if successor not in seen:
                seen.add(successor)
                pending.append(successor)
    return False
Instruction.reduce_packet and same_buf_dst(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): op.srcs.append( ( @@ -303,6 +307,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.inst == Instruction.signal and same_buf_src(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): op.dsts.append( ( @@ -334,6 +339,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.inst == Instruction.wait and same_buf_dst(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): op.srcs.append( ( @@ -376,6 +382,7 @@ def _optimize_rrcs_rs(self): and same_count(op, next_op) and buf_dst_src_match(op, next_op) and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) ): if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue @@ -404,6 +411,7 @@ def _optimize_rrcs_rs(self): and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm + and not circular_dep_after_merge(op, next_op) ): if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue @@ -433,6 +441,7 @@ def _optimize_rrcs_rs(self): and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm + and not circular_dep_after_merge(op, next_op) ): if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue @@ -473,6 +482,7 @@ def _optimize_get_put(self): and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( @@ -501,6 +511,7 @@ def _optimize_get_put(self): and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( @@ -529,6 +540,7 @@ def _optimize_get_put(self): and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and 
same_count(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( @@ -569,6 +581,7 @@ def _parallel_signal_wait(self): seq_op.inst == Instruction.signal and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( ( @@ -596,6 +609,7 @@ def _parallel_signal_wait(self): seq_op.inst == Instruction.wait and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) + and not circular_dep_after_merge(op, seq_op) ): op.dsts.append( (