Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Jun 12, 2024
1 parent 3a18b78 commit e3ae648
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
20 changes: 20 additions & 0 deletions msccl/language/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ def remove_op(op: Op):
for p in op.prev:
p.next.remove(op)
p.next += op.next
p.next = list(set(p.next))

for n in op.next:
n.prev.remove(op)
n.prev = op.prev.union(n.prev)

op.next = []
op.prev = []


def merge_op(op: Op, other_op: Op):
if other_op in op.next:
Expand All @@ -34,6 +38,22 @@ def merge_op(op: Op, other_op: Op):
op.next = list(set(op.next + other_op.next))


def circular_dep_after_merge(op: Op, other_op: Op):
    """Return True if merging ``other_op`` into ``op`` would create a cycle.

    The search starts from every successor (``.next``) of the two ops,
    ignoring the direct ``op -> other_op`` edge, and follows ``.next`` links.
    If either of the two ops is reached again, some path leaves the pair and
    comes back to it, so the merged node would depend on itself.

    Args:
        op: The op that would absorb ``other_op``.
        other_op: The op that would be merged into ``op``.

    Returns:
        True if a successor reachable from the starting set has ``op`` or
        ``other_op`` among its own successors, False otherwise.
    """
    root = {op, other_op}

    # Seed the search with the successors of both ops; the direct edge
    # op -> other_op is dropped so it is not itself followed.
    pending = set(op.next)
    pending.discard(other_op)
    pending.update(other_op.next)

    # Track expanded nodes so shared successors are walked once and a cycle
    # that does not involve the root pair cannot make the search loop forever.
    visited = set(pending)
    stack = list(pending)
    while stack:
        current = stack.pop()
        for n in current.next:
            # The root node will be visited again if there is a circular dependency
            if n in root:
                return True
            if n not in visited:
                visited.add(n)
                stack.append(n)
    return False


def same_tb(op1: Op, op2: Op):
    """Tell whether two ops agree on both their ``tb`` and ``channel``."""
    # Bail out on the first mismatch; channel is only compared when tb matches.
    if not op1.tb == op2.tb:
        return False
    return op1.channel == op2.channel

Expand Down
14 changes: 14 additions & 0 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
buf_dst_src_match,
merge_op,
remove_op,
circular_dep_after_merge,
same_buf_dst,
same_buf_src,
same_chan_type,
Expand Down Expand Up @@ -234,6 +235,7 @@ def _optimize_rrc_r_signal_wait(self):
and same_count(op, next_op)
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
Expand All @@ -257,6 +259,7 @@ def _optimize_rrc_r_signal_wait(self):
next_op.inst == Instruction.reduce
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
Expand All @@ -280,6 +283,7 @@ def _optimize_rrc_r_signal_wait(self):
next_op.inst == Instruction.reduce_packet
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
Expand All @@ -303,6 +307,7 @@ def _optimize_rrc_r_signal_wait(self):
next_op.inst == Instruction.signal
and same_buf_src(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -334,6 +339,7 @@ def _optimize_rrc_r_signal_wait(self):
next_op.inst == Instruction.wait
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
Expand Down Expand Up @@ -376,6 +382,7 @@ def _optimize_rrcs_rs(self):
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
continue
Expand Down Expand Up @@ -404,6 +411,7 @@ def _optimize_rrcs_rs(self):
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and next_op.channel_type == ChannelType.sm
and not circular_dep_after_merge(op, next_op)
):
if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
continue
Expand Down Expand Up @@ -433,6 +441,7 @@ def _optimize_rrcs_rs(self):
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and next_op.channel_type == ChannelType.sm
and not circular_dep_after_merge(op, next_op)
):
if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
continue
Expand Down Expand Up @@ -473,6 +482,7 @@ def _optimize_get_put(self):
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -501,6 +511,7 @@ def _optimize_get_put(self):
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -529,6 +540,7 @@ def _optimize_get_put(self):
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -569,6 +581,7 @@ def _parallel_signal_wait(self):
seq_op.inst == Instruction.signal
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -596,6 +609,7 @@ def _parallel_signal_wait(self):
seq_op.inst == Instruction.wait
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down

0 comments on commit e3ae648

Please sign in to comment.