From bebbbe13c999e1589152938a001f76d436cf4eb9 Mon Sep 17 00:00:00 2001
From: Binyang Li
Date: Thu, 12 Sep 2024 09:27:38 +0000
Subject: [PATCH] fix

---
Notes (ignored by git-am; not part of the commit message):
  * _remove_redundant_signal_wait now delegates the removal of a signal op
    following put_packet, and of a wait op preceding reduce_packet or
    copy_packet, to InstructionOptimizer.try_remove_op.
  * try_remove_op takes (pending_remove_op, condition), so both call sites
    pass exactly those two arguments.
  * The wait check inspects prev_op.inst, the op being iterated in that
    loop, rather than next_op.
  * Six names are dropped from the msccl.language.instruction_dag import
    list in instruction_dag.py; remove_op is added to the import list in
    instruction_optimizer.py.

 msccl/language/mscclpp/instruction_dag.py     | 26 ++++++-------------
 .../language/mscclpp/instruction_optimizer.py |  7 +++++
 2 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py
index 703b0e1..1fdda79 100644
--- a/msccl/language/mscclpp/instruction_dag.py
+++ b/msccl/language/mscclpp/instruction_dag.py
@@ -4,14 +4,8 @@
 
 from msccl.language.buffer import Buffer
 from msccl.language.instruction_dag import (
-    buf_dst_src_match,
-    merge_op,
-    remove_op,
-    circular_dep_after_merge,
     same_buf_dst,
     same_buf_src,
-    same_chan_type,
-    same_count,
     same_src_dst_buffer_type,
 )
 from msccl.language.instruction_dag import InstructionDAG
@@ -223,30 +217,26 @@ def complete_channels(self):
                 tb.channels = list(chans)
 
     def _remove_redundant_signal_wait(self):
+        optimizer = InstructionOptimizer()
         # For packet ops, we can remove signal/wait
         for rank, rank_tbs in enumerate(self.tbs):
             for tbid, tb in rank_tbs.items():
                 queue = list(tb.ops)
                 while len(queue) > 0:
                     op = queue[0]
+                    fused = False
                     if op.inst == Instruction.put_packet:
-                        fused = False
                         for next_op in op.next:
-                            if next_op.inst == Instruction.signal:
-                                remove_op(next_op)
-                                fused = True
+                            fused = optimizer.try_remove_op(next_op, next_op.inst == Instruction.signal)
+                            if fused:
                                 break
-                        if fused:
-                            continue
                     elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet:
-                        fused = False
                         for prev_op in op.prev:
-                            if prev_op.inst == Instruction.wait:
-                                remove_op(prev_op)
-                                fused = True
+                            fused = optimizer.try_remove_op(prev_op, prev_op.inst == Instruction.wait)
+                            if fused:
                                 break
-                        if fused:
-                            continue
+                    if fused:
+                        continue
                     queue = queue[1:]
 
     # put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di)
diff --git a/msccl/language/mscclpp/instruction_optimizer.py b/msccl/language/mscclpp/instruction_optimizer.py
index 6640851..3f186c4 100644
--- a/msccl/language/mscclpp/instruction_optimizer.py
+++ b/msccl/language/mscclpp/instruction_optimizer.py
@@ -5,6 +5,7 @@
     buf_dst_src_match,
     circular_dep_after_merge,
     merge_op,
+    remove_op,
     same_chan_type,
     same_count,
     same_buf_dst,
@@ -180,3 +181,9 @@ def try_fuse_instructions_using_proxy_channel(
             queue.remove(next_op)
             return True
         return False
+
+    def try_remove_op(self, pending_remove_op: Op, condition: bool) -> bool:
+        if condition:
+            remove_op(pending_remove_op)
+            return True
+        return False