Skip to content

Commit

Permalink
Add flush op (#13)
Browse files Browse the repository at this point in the history
Add flush op
Refactor optimizer part
  • Loading branch information
Binyang2014 authored Sep 24, 2024
1 parent 9c94c02 commit ea828bf
Show file tree
Hide file tree
Showing 7 changed files with 428 additions and 412 deletions.
50 changes: 50 additions & 0 deletions examples/mscclang/mscclpp/send_recv_proxy_mscclpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import SendRecv


def send_recv(instances):
    """Build the two-rank "send_recv" MSCCLPP program over proxy channels.

    Each rank puts chunk 0 of its input buffer into index 1 of the other
    rank's "scratch" buffer, then issues a signal and a flush toward that
    rank. Afterwards each rank waits on its own scratch chunk and copies it
    to index 0 of its output buffer. ``Json()`` and ``Check()`` are called
    last, inside the program context (presumably emitting and validating
    the generated plan -- confirm against msccl.language).

    Args:
        instances: forwarded unchanged to ``MSCCLPPProgram``.
    """
    num_ranks = 2
    chunks_per_loop = 1
    topo = fully_connected(num_ranks)
    coll = SendRecv(num_ranks, chunks_per_loop, False)
    proxy = ChannelType.proxy

    with MSCCLPPProgram("send_recv", topo, coll, instances):
        # Every ordered (source, peer) pair of distinct ranks, source-major.
        rank_pairs = [
            (src, peer)
            for src in range(num_ranks)
            for peer in range(num_ranks)
            if peer != src
        ]
        for src, peer in rank_pairs:
            outgoing = chunk(src, Buffer.input, 0)
            outgoing.put(peer, "scratch", 1, sendtb=0, chan_type=proxy)
            outgoing.signal(peer, "scratch", 1, sendtb=0, chan_type=proxy)
            outgoing.flush(peer, "scratch", 1, sendtb=0, chan_type=proxy)

        for rank in range(num_ranks):
            incoming = chunk(rank, "scratch", 1)
            # 1 - rank names the other rank; valid only because num_ranks is 2.
            incoming.wait(1 - rank, Buffer.input, 0, recvtb=0, chan_type=proxy)
            incoming.copy(rank, Buffer.output, 0, sendtb=0)

        Json()
        Check()


# Command-line entry point: parse the single required positional integer
# argument "instances" and build the send_recv program with it.
parser = argparse.ArgumentParser()
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

# NOTE: there is no __main__ guard, so this also runs when the module is imported.
send_recv(args.instances)
13 changes: 13 additions & 0 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def check(self):
def lower(self):
self._convert_to_exectuion_plan()
self.instr_dag.complete_channels()
self.instr_dag.remove_redundant_signal_wait()
if self.instr_fusion:
self.instr_dag.optimize()
self.instr_dag.lower_pt1(self.instances)
Expand Down Expand Up @@ -255,6 +256,18 @@ def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type)

# Only proxy channels need to use this function.
def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.proxy):
    """Register a flush from this chunk's rank toward rank ``dst``.

    Args:
        dst: destination rank; must differ from ``self.rank``.
        buffer: destination buffer, resolved together with ``index`` by
            ``self._get_buffer_index`` (presumably filling in defaults --
            confirm against that helper).
        index: destination chunk index; see ``buffer``.
        sendtb: forwarded unchanged to ``instr_dag.add_flush``.
        chan_type: must be ``ChannelType.proxy``; it is only validated
            here and is not forwarded to ``add_flush``.

    Raises:
        AssertionError: if ``chan_type`` is not proxy, if ``dst`` equals
            ``self.rank``, or if the topology reports no link to ``dst``.
    """
    assert chan_type == ChannelType.proxy, "Only proxy channel can use flush"
    sender = self.rank
    assert sender != dst, "Cannot flush to the same rank"
    buffer, index = self._get_buffer_index(dst, buffer, index)

    # dst != self.rank is guaranteed by the assert above, so only the link
    # itself needs checking here.
    assert self.prog.topo.link(self.rank, dst), f"No link from {self.rank} to {dst}"
    dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
    self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)

def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
sender = src
receiver = self.rank
Expand Down
Loading

0 comments on commit ea828bf

Please sign in to comment.