Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Sep 12, 2024
1 parent 7818988 commit bebbbe1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
26 changes: 8 additions & 18 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,8 @@

from msccl.language.buffer import Buffer
from msccl.language.instruction_dag import (
buf_dst_src_match,
merge_op,
remove_op,
circular_dep_after_merge,
same_buf_dst,
same_buf_src,
same_chan_type,
same_count,
same_src_dst_buffer_type,
)
from msccl.language.instruction_dag import InstructionDAG
Expand Down Expand Up @@ -223,30 +217,26 @@ def complete_channels(self):
tb.channels = list(chans)

def _remove_redundant_signal_wait(self):
optimizer = InstructionOptimizer()
# For packet ops, we can remove signal/wait
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
queue = list(tb.ops)
while len(queue) > 0:
op = queue[0]
fused = False
if op.inst == Instruction.put_packet:
fused = False
for next_op in op.next:
if next_op.inst == Instruction.signal:
remove_op(next_op)
fused = True
fused = optimizer.try_remove_op(next_op, tb, queue, next_op.inst == Instruction.signal)
if fused:
break
if fused:
continue
elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet:
fused = False
for prev_op in op.prev:
if prev_op.inst == Instruction.wait:
remove_op(prev_op)
fused = True
fused = optimizer.try_remove_op(prev_op, tb, queue, next_op.inst == Instruction.wait)
if fused:
break
if fused:
continue
if fused:
continue
queue = queue[1:]

# put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di)
Expand Down
7 changes: 7 additions & 0 deletions msccl/language/mscclpp/instruction_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
buf_dst_src_match,
circular_dep_after_merge,
merge_op,
remove_op,
same_chan_type,
same_count,
same_buf_dst,
Expand Down Expand Up @@ -180,3 +181,9 @@ def try_fuse_instructions_using_proxy_channel(
queue.remove(next_op)
return True
return False

def try_remove_op(self, pending_remove_op: Op, condition: bool) -> bool:
if condition:
remove_op(pending_remove_op)
return True
return False

0 comments on commit bebbbe1

Please sign in to comment.