diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py index 7bd70bf..d099126 100755 --- a/msccl/language/instruction_dag.py +++ b/msccl/language/instruction_dag.py @@ -19,6 +19,9 @@ def remove_op(op: Op): def merge_op(op: Op, other_op: Op): + if other_op in op.next: + op.next.remove(other_op) + other_op.prev.remove(op) for p in other_op.prev: p.next.remove(other_op) p.next.append(op) @@ -28,7 +31,7 @@ def merge_op(op: Op, other_op: Op): n.prev.add(op) op.prev = op.prev.union(other_op.prev) - op.next += other_op.next + op.next = list(set(op.next + other_op.next)) def same_tb(op1: Op, op2: Op): diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index 28dcd8b..32d15b1 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -243,7 +243,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.step, ) ) - remove_op(next_op) + merge_op(op, next_op) tb.ops.remove(next_op) queue.remove(next_op) fused = True @@ -266,7 +266,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.step, ) ) - remove_op(next_op) + merge_op(op, next_op) tb.ops.remove(next_op) queue.remove(next_op) fused = True @@ -289,7 +289,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.step, ) ) - remove_op(next_op) + merge_op(op, next_op) tb.ops.remove(next_op) queue.remove(next_op) fused = True @@ -320,7 +320,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.step, ) ) - remove_op(next_op) + merge_op(op, next_op) tb.ops.remove(next_op) queue.remove(next_op) fused = True @@ -351,7 +351,7 @@ def _optimize_rrc_r_signal_wait(self): next_op.step, ) ) - remove_op(next_op) + merge_op(op, next_op) tb.ops.remove(next_op) queue.remove(next_op) fused = True @@ -389,7 +389,7 @@ def _optimize_rrcs_rs(self): next_op.step, ) ) - remove_op(next_op) + merge_op(op, next_op) tb.ops.remove(next_op) queue.remove(next_op) fused = True @@ -418,7 +418,7 @@ def _optimize_rrcs_rs(self): next_op.step, ) ) - remove_op(next_op) + merge_op(op, next_op) tb.ops.remove(next_op) queue.remove(next_op) fused = True @@ -447,7 +447,7 @@ def _optimize_rrcs_rs(self): next_op.step, ) ) - remove_op(next_op) + merge_op(op, next_op) tb.ops.remove(next_op) queue.remove(next_op) fused = True