Skip to content

Commit

Permalink
Fix op fusion issue (#21)
Browse files Browse the repository at this point in the history
Consider the case op2.prev = [op1, op3], op1.next = [op2], op3.next = [op2], where op1 and op2 satisfy the conditions to be merged.
We only apply the merge if all previous ops of op2 have been visited after the merge. This ensures the result respects the user's algorithm.
  • Loading branch information
Binyang2014 authored Oct 28, 2024
1 parent 3c94ad5 commit 79ed5ae
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
10 changes: 10 additions & 0 deletions msccl/language/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ def circular_dep_after_merge(op: Op, other_op: Op):
frontier.append(n)
frontier = frontier[1:]

"""
Consider the case op2.prev = [op1, op3], op1.next = [op2], op3.next = [op2], where op1 and op2 satisfy the conditions to be merged.
We only apply the merge if all previous ops of op2 have already been visited (i.e. op1 is the last previous op of op2).
"""
def all_prevs_visited_after_merge(op: Op, other_op: Op):
    """Tell whether merging ``other_op`` into ``op`` keeps every predecessor first.

    Args:
        op: The op that would absorb ``other_op``; only its ``step`` is read.
        other_op: The merge candidate; only its ``prev`` collection is read.

    Returns:
        ``True`` when no op in ``other_op.prev`` has a ``step`` greater than
        ``op.step`` (including when ``other_op.prev`` is empty), else ``False``.
    """
    # NOTE(review): compares raw ``step`` values only; presumably every prev
    # op shares a step numbering with ``op`` -- confirm against callers.
    merge_step = op.step
    return not any(pred.step > merge_step for pred in other_op.prev)

def same_tb(op1: Op, op2: Op):
    """Check that two ops agree on both their ``tb`` and ``channel`` attributes.

    Args:
        op1: First op to compare.
        op2: Second op to compare.

    Returns:
        The result of the ``tb`` comparison when it is falsy; otherwise the
        result of the ``channel`` comparison.
    """
    tb_matches = op1.tb == op2.tb
    if not tb_matches:
        # Mirror short-circuit behaviour: ``channel`` is not inspected here.
        return tb_matches
    return op1.channel == op2.channel
Expand Down
5 changes: 5 additions & 0 deletions msccl/language/mscclpp/instruction_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
same_count,
same_buf_dst,
same_buf_src,
all_prevs_visited_after_merge,
)
from msccl.language.types import ChunkRef, ChannelType, MscclppInstruction as Instruction, Op, Threadblock

Expand Down Expand Up @@ -41,6 +42,7 @@ def try_merge_same_instructions(
and same_count(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
# Append the source chunks from next_op
op.srcs.append(
Expand Down Expand Up @@ -85,6 +87,7 @@ def try_compact_instructions(
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
and all_prevs_visited_after_merge(op, seq_op)
):
# Append the source and destination chunks from seq_op
op.dsts.append(
Expand Down Expand Up @@ -124,6 +127,7 @@ def try_fuse_with_put(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -
and next_op.channel_type == ChannelType.sm
and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.sm)
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
return False
Expand Down Expand Up @@ -170,6 +174,7 @@ def try_fuse_instructions_using_proxy_channel(
and same_chan_type(op, next_op)
and op.channel_type == ChannelType.proxy
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
if op.inst == Instruction.put and next_op.inst == Instruction.signal:
op.inst = Instruction.put_with_signal
Expand Down

0 comments on commit 79ed5ae

Please sign in to comment.