Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Nov 4, 2024
1 parent 2a04c61 commit 4b30b54
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 31 deletions.
50 changes: 50 additions & 0 deletions examples/mscclang/mscclpp/allreduce_nvls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllReduce


def allreduce_allpairs(gpus, instances):
    """Generate an NVLS-based allreduce MSCCL++ program and emit it as JSON.

    Each rank owns chunk `rank` of the input buffer. After a signal/wait
    barrier with every peer, the rank multicast-reduces its chunk against the
    matching chunk on every other rank (group_load_reduce) and then
    broadcasts the reduced result back to all peers (group_store).

    Args:
        gpus: number of GPUs (ranks); also used as the chunks-per-loop count.
        instances: number of program instances to instantiate.
    """
    size = gpus
    chunksperloop = gpus
    topology = fully_connected(size)
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram(
        "allreduce_nvls",
        topology,
        collective,
        instances,
    ):
        for rank in range(size):
            index = rank
            c = chunk(rank, Buffer.input, index)
            # BUGFIX: reduce_chunks must hold the *peers'* chunks
            # (chunk(nghr, ...)). The original appended chunk(rank, ...)
            # N-1 times, so group_load_reduce reduced this rank's chunk
            # with itself and peer data was never read.
            reduce_chunks = []
            # Make sure the data is ready: this rank signals every peer,
            # then waits for every peer's signal before reducing.
            for nghr in range(size):
                if rank != nghr:
                    reduce_chunks.append(chunk(nghr, Buffer.input, index))
                    c.signal(nghr, Buffer.input, index, sendtb=0)
            for nghr in range(size):
                if rank != nghr:
                    c.wait(nghr, Buffer.input, index, recvtb=0)
            c.group_load_reduce(reduce_chunks, recvtb=0)
            # Broadcast the reduced chunk back to all peers.
            ngbrs = [nghr for nghr in range(size) if nghr != rank]
            c.group_store(ngbrs, sendtb=0)

        Json()
        # Check()


# Guard the script entry point so importing this module (e.g. from tests or
# tooling) does not parse sys.argv or run the program generator.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("num_gpus", type=int, help="number of gpus")
    parser.add_argument("instances", type=int, help="number of instances")

    args = parser.parse_args()

    allreduce_allpairs(args.num_gpus, args.instances)
54 changes: 23 additions & 31 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,39 +331,31 @@ def reduce_packet(self, other_chunkref, recvtb=-1):
# Group operations. These operations are used to perform collective operations across multiple chunks.
# For now, all chunks must have the same buffer type and offset.
# """
def _assert_same_node(self, other_chunkrefs):
nranks_per_node = self.prog.collective.num_ranks_per_node
for i in range(len(other_chunkrefs)):
assert (
self.rank % nranks_per_node == other_chunkrefs[i].rank % nranks_per_node
), "Group operations only supports chunks on the same node"

def _assert_same_index(self, other_chunkrefs):
for i in range(len(other_chunkrefs)):
assert self.index == other_chunkrefs[i].index, "Group operations only supports chunks with the same index"

def _assert_same_buffer(self, other_chunkrefs):
for i in range(len(other_chunkrefs)):
assert (
self.buffer == other_chunkrefs[i].buffer
), "Group operations only supports chunks with the same buffer"

def _group_load_reduce(self, other_chunkrefs: list, recvtb=-1):
# may need to check if sharp supported in topologies
pass

def _group_store(self, other_chunkrefs: list, sendtb=-1):
pass

# Reads the chunk(s) referenced by other_chunkref and reduce into the chunk referenced by this chunkref
def group_load_reduce(self, other_chunkrefs: list, recvtb: int, chan_type=ChannelType.nvls):
def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls):
assert (
len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls
), "Group load reduce only supports nvls channel"
self._assert_same_node(other_chunkrefs)
self._assert_same_index(other_chunkrefs)
self._assert_same_buffer(other_chunkrefs)
self._group_load_reduce(other_chunkrefs, recvtb)
nranks_per_node = self.prog.collective.num_ranks_per_node
for other_chunkref in other_chunkrefs:
assert (
self.rank // nranks_per_node == other_chunkref.rank // nranks_per_node
), "Group load reduce only supports chunks on the same node"
assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer"
assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index"

src_chunkref = other_chunkref
self.prog.apply_reduce(
src_chunkref.rank,
src_chunkref.buffer,
src_chunkref.index,
self.rank,
self.buffer,
self.index,
self.size,
)
self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type)
return self

# Copies the chunk(s) referenced by this chunkref onto other_chunkrefs
def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
Expand All @@ -380,14 +372,14 @@ def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=Ch
assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}"
assert self.buffer == buffer, "Group store only supports chunks with the same buffer"
assert (
self.rank % nrank_per_node == dst % nrank_per_node
self.rank // nrank_per_node == dst // nrank_per_node
), "Group store only supports chunks on the same node"

dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
other_chunkrefs.append(dst_chunkref)
# add new op here
pass
self.prog.instr_dag.add_group_store(self.rank, self, other_chunkrefs, sendtb, chan_type)

def get_origin_index(self, index=0):
return self._get_chunk(index + self.index).origin_index
Expand Down
43 changes: 43 additions & 0 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,56 @@ def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type):
self._write(rank, buffer, index, size, op, read=True)
return op

def add_group_load_reduce(self, rank, send_refs, recv_ref, tb, ch_type):
tb_step = self._get_tb_step(rank, tb)
op = Op(
Instruction.group_load_reduce,
rank,
None,
recv_ref,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
for send_ref in send_refs:
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
buffer = recv_ref.buffer
index = recv_ref.index
size = recv_ref.size
self._write(rank, buffer, index, size, op, read=True)

def add_group_store(self, rank, send_ref, recv_refs, tb, ch_type):
    """Record a group_store op in the instruction DAG.

    Broadcasts `send_ref` from `rank` to each destination chunk in
    `recv_refs`, on thread block `tb` over channel `ch_type` (NVLS).
    Returns the created Op.
    """
    step = self._get_tb_step(rank, tb)
    op = Op(
        Instruction.group_store,
        rank,
        send_ref,
        None,
        next=set(),
        prev=set(),
        tb=tb,
        channel_type=ch_type,
        step=step,
    )
    # Snapshot each destination as an immutable ChunkRef at this tb step.
    op.dsts.extend(
        (ChunkRef(dst.rank, dst.buffer, dst.index, dst.size), step)
        for dst in recv_refs
    )
    # The source chunk is only read by the store.
    self._read(rank, send_ref.buffer, send_ref.index, send_ref.size, op)
    return op

def complete_channels(self):
send_op = [Instruction.put, Instruction.signal, Instruction.put_packet]
recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy]
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
chans = set()
for op in tb.ops:
if op.channel_type == ChannelType.none or op.channel_type == ChannelType.nvls:
continue
src_buffer = (
Buffer.scratch
if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output
Expand Down
3 changes: 3 additions & 0 deletions msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class MscclppInstruction(Enum):
wait = "wait"
signal = "signal"
flush = "flush"
# NVLS group operations: the string is the mnemonic emitted in the JSON plan.
group_store = "gstore"
group_load_reduce = "glre"
group_load_reduce_store = "glres"

def __str__(self):
    """Return the wire-format mnemonic (the enum's string value)."""
    return self.value
Expand Down

0 comments on commit 4b30b54

Please sign in to comment.