Skip to content

Commit

Permalink
Bug fix: avoid op merge if circular dependencies would be introduced (#6)
Browse files Browse the repository at this point in the history
- generate json for reduce_packet OP
- check whether circular dependencies would be introduced when doing op fusion
  • Loading branch information
Binyang2014 authored Jun 13, 2024
1 parent f6a3b39 commit 610a499
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 3 deletions.
6 changes: 4 additions & 2 deletions msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,10 @@ def get_buffer_index(self, rank, buffer, index):

class AllReduce(Collective):

def __init__(self, num_ranks, chunk_factor, inplace):
Collective.__init__(self, num_ranks, chunk_factor, inplace, num_ranks)
def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups=None):
if num_chunk_groups == None:
num_chunk_groups = num_ranks
Collective.__init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups)
self.name = "allreduce"

def init_buffers(self):
Expand Down
20 changes: 20 additions & 0 deletions msccl/language/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ def remove_op(op: Op):
for p in op.prev:
p.next.remove(op)
p.next += op.next
p.next = list(set(p.next))

for n in op.next:
n.prev.remove(op)
n.prev = op.prev.union(n.prev)

op.next = []
op.prev = []


def merge_op(op: Op, other_op: Op):
if other_op in op.next:
Expand All @@ -34,6 +38,22 @@ def merge_op(op: Op, other_op: Op):
op.next = list(set(op.next + other_op.next))


def circular_dep_after_merge(op: Op, other_op: Op) -> bool:
    """Return True if fusing ``op`` and ``other_op`` would create a cycle.

    The two ops are treated as one merged node. Starting from the merged
    node's successors (``op.next`` without the direct edge to ``other_op``,
    plus ``other_op.next``), the graph is walked along ``next`` edges. If any
    visited node has ``op`` or ``other_op`` as a successor, a path leads back
    into the merged node and the merge would introduce a circular dependency.

    Args:
        op: First op of the candidate merge; only its ``next`` links are read.
        other_op: Second op of the candidate merge.

    Returns:
        True if a path from the merged node's successors leads back to either
        op, False otherwise.
    """
    root = {op, other_op}

    # The direct op -> other_op edge collapses inside the merged node, so it
    # is not a path that needs to be followed.
    frontier = set(op.next)
    frontier.discard(other_op)
    frontier.update(other_op.next)

    # Track visited nodes so that shared successors are expanded only once and
    # a cycle that does not involve the root ops cannot make the walk spin
    # forever.
    visited = set(frontier)
    pending = list(frontier)
    while pending:
        current = pending.pop()
        for n in current.next:
            # The root node will be visited again if there is a circular dependency
            if n in root:
                return True
            if n not in visited:
                visited.add(n)
                pending.append(n)
    return False


def same_tb(op1: Op, op2: Op):
    """Tell whether two ops share both their ``tb`` and ``channel`` values."""
    tb_matches = op1.tb == op2.tb
    if not tb_matches:
        # Mirror short-circuit `and`: hand back the failed comparison result.
        return tb_matches
    return op1.channel == op2.channel

Expand Down
16 changes: 16 additions & 0 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
buf_dst_src_match,
merge_op,
remove_op,
circular_dep_after_merge,
same_buf_dst,
same_buf_src,
same_chan_type,
Expand Down Expand Up @@ -234,6 +235,7 @@ def _optimize_rrc_r_signal_wait(self):
and same_count(op, next_op)
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
Expand All @@ -257,6 +259,7 @@ def _optimize_rrc_r_signal_wait(self):
next_op.inst == Instruction.reduce
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
Expand All @@ -280,6 +283,7 @@ def _optimize_rrc_r_signal_wait(self):
next_op.inst == Instruction.reduce_packet
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
Expand All @@ -303,6 +307,7 @@ def _optimize_rrc_r_signal_wait(self):
next_op.inst == Instruction.signal
and same_buf_src(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -334,6 +339,7 @@ def _optimize_rrc_r_signal_wait(self):
next_op.inst == Instruction.wait
and same_buf_dst(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
op.srcs.append(
(
Expand Down Expand Up @@ -376,6 +382,7 @@ def _optimize_rrcs_rs(self):
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
):
if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
continue
Expand Down Expand Up @@ -404,6 +411,7 @@ def _optimize_rrcs_rs(self):
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and next_op.channel_type == ChannelType.sm
and not circular_dep_after_merge(op, next_op)
):
if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
continue
Expand Down Expand Up @@ -433,6 +441,7 @@ def _optimize_rrcs_rs(self):
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and next_op.channel_type == ChannelType.sm
and not circular_dep_after_merge(op, next_op)
):
if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
continue
Expand Down Expand Up @@ -473,6 +482,7 @@ def _optimize_get_put(self):
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -501,6 +511,7 @@ def _optimize_get_put(self):
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -529,6 +540,7 @@ def _optimize_get_put(self):
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand All @@ -546,6 +558,8 @@ def _optimize_get_put(self):
tb.ops.remove(seq_op)
queue.remove(seq_op)
fused = True
if fused:
continue
queue = queue[1:]

# For signal/wait ops, if they are independent of other operations and no other operations in between,
Expand All @@ -567,6 +581,7 @@ def _parallel_signal_wait(self):
seq_op.inst == Instruction.signal
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down Expand Up @@ -594,6 +609,7 @@ def _parallel_signal_wait(self):
seq_op.inst == Instruction.wait
and same_src_dst_buffer_type(op, seq_op)
and same_chan_type(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
):
op.dsts.append(
(
Expand Down
2 changes: 1 addition & 1 deletion msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def remove_empty_fields(d):
srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))
dst = op.dst
src = op.dst # TODO(binyli): fix this
elif op.inst == Instruction.reduce:
elif op.inst == Instruction.reduce or op.inst == Instruction.reduce_packet:
srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))
dst = op.dst
elif op.inst == Instruction.nop:
Expand Down
29 changes: 29 additions & 0 deletions tests/test_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,35 @@ def test_instruction_fusion_mscclpp():
assert lowered_prgm.gpus[2].threadblocks[0].ops[2].inst == MscclppInstruction.signal


def test_instruction_fusion_multi_deps_mscclpp():
    """Check op fusion on rank 1 when a merge would create a circular dependency.

    Builds a 3-rank allreduce program, lowers it, and asserts that the first
    two ops of rank 1's first threadblock are ``reduce_send_packet`` followed
    by a separate ``reduce_packet``.
    """
    topology = fully_connected(3)
    collective = AllReduce(3, 1, True)
    prgm = MSCCLPPProgram("allreduce", topology, collective, 1)
    # The dependency graph for rank 1 is as follows:
    # put(0i to 1s) => reduce(1s to 1i) => put(2i to 1s) => reduce(1s to 1i)
    # | => put(1i to 0s) ^
    # | => put(1i to 2s)------------------- -|
    # put(2i to 1s) => reduce(1s to 1i) for read after write
    # put(1i to 2s) => reduce(1s to 1i) for write after read
    # when we try to merge reduce(1s to 1i) => put(2i to 1s) => reduce(1s to 1i),
    # circular dependency is introduced
    with prgm:
        c0 = chunk(0, Buffer.input, 0)
        c0.put_packet(1, "scratch", 0, sendtb=0)
        c1s = chunk(1, "scratch", 0)
        c1 = chunk(1, Buffer.input, 0)
        c1 = c1.reduce_packet(c1s, recvtb=0)
        c1.put_packet(0, "scratch", 0, sendtb=0)
        c1.put_packet(2, "scratch", 0, sendtb=0)
        c2 = chunk(2, Buffer.input, 0)
        c2.put_packet(1, "scratch", 0, sendtb=0)
        c1.reduce_packet(c1s, recvtb=0)
    lowered_prgm = prgm.lower()
    # Drop threadblocks whose id is -1 before indexing into the list
    # (presumably placeholders for unassigned threadblocks -- TODO confirm).
    lowered_prgm.gpus[1].threadblocks = [tb for tb in lowered_prgm.gpus[1].threadblocks if tb.id != -1]
    # The two reduce_packet ops must not be fused together: the second one is
    # expected to remain a standalone reduce_packet after the fused first op.
    assert lowered_prgm.gpus[1].threadblocks[0].ops[0].inst == MscclppInstruction.reduce_send_packet
    assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == MscclppInstruction.reduce_packet


def test_replication():
topology = fully_connected(2)
collective = AllToAll(2, 1, False)
Expand Down

0 comments on commit 610a499

Please sign in to comment.