Update for NVLS OP (#24)
Support NVLS in msccl-tools
Binyang2014 authored Nov 26, 2024
1 parent 79ed5ae commit c7ba9b0
Showing 7 changed files with 315 additions and 33 deletions.
49 changes: 49 additions & 0 deletions examples/mscclang/mscclpp/allreduce_nvls.py
@@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllReduce


def allreduce_nvls(gpus, instances):
size = gpus
chunksperloop = gpus
topology = fully_connected(size)
collective = AllReduce(size, chunksperloop, True)
with MSCCLPPProgram(
"allreduce_nvls",
topology,
collective,
instances,
):
# Each rank owns chunk `rank`: it reduces the matching chunk from every peer into it, then broadcasts the result back
for rank in range(size):
index = rank
c = chunk(rank, Buffer.input, index)
reduce_chunks = []
# signal all peers, then wait until every peer's chunk is ready
for nghr in range(size):
if rank != nghr:
c_peer = chunk(nghr, Buffer.input, index)
reduce_chunks.append(c_peer)
c.signal(nghr, Buffer.input, index, sendtb=0)
for nghr in range(size):
if rank != nghr:
c.wait(nghr, Buffer.input, index, recvtb=0)
c = c.group_load_reduce(reduce_chunks, recvtb=0)
ngbrs = [nghr for nghr in range(size) if nghr != rank]
c.group_store(ngbrs, sendtb=0)

Json()
Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

allreduce_nvls(args.num_gpus, args.instances)
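
For reference, a hedged usage sketch (argument values are illustrative; NVLS assumes the GPUs share an NVSwitch-connected node):

    python allreduce_nvls.py 8 1

This prints the generated execution plan as JSON via Json() and verifies the final chunk placement via Check().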
26 changes: 19 additions & 7 deletions msccl/language/collectives.py
@@ -4,13 +4,24 @@


class Collective:
def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups=1):

def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
self.num_ranks = num_ranks
self.chunk_factor = chunk_factor
self.inplace = inplace
self.name = "custom"
# Divide the buffer into num_chunk_groups groups
self.num_chunk_groups = num_chunk_groups
# Divide the buffer into num_chunk_groups groups
if num_ranks_per_node == -1:
self.num_ranks_per_node = num_ranks
else:
self.num_ranks_per_node = num_ranks_per_node

# kwargs
# num_chunk_groups: the n chunks are grouped into m groups.
# Every group is guaranteed to hold the same number of chunks,
# but chunks within a group may differ in size when the data
# cannot be divided evenly among the chunks.
self.num_chunk_groups = kwargs.get("num_chunk_groups", 1)

def init_buffers(self):
pass
@@ -120,10 +131,11 @@ def get_buffer_index(self, rank, buffer, index):

class AllReduce(Collective):

def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups=None):
if num_chunk_groups == None:
num_chunk_groups = num_ranks
Collective.__init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups)
def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
num_chunk_groups = kwargs.get('num_chunk_groups', num_ranks)
Collective.__init__(
self, num_ranks, chunk_factor, inplace, num_ranks_per_node, num_chunk_groups=num_chunk_groups
)
self.name = "allreduce"

def init_buffers(self):
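To make the num_chunk_groups rule above concrete, a standalone sketch (hypothetical helper, not part of this commit): every group holds the same number of chunks, while chunk sizes inside a group may differ when the element count does not divide evenly.

    def split_into_groups(total_elems, num_chunks, num_chunk_groups):
        # Hypothetical illustration of the grouping rule in Collective.__init__
        assert num_chunks % num_chunk_groups == 0
        base, rem = divmod(total_elems, num_chunks)
        chunk_sizes = [base + 1 if i < rem else base for i in range(num_chunks)]
        per_group = num_chunks // num_chunk_groups
        return [chunk_sizes[g * per_group : (g + 1) * per_group] for g in range(num_chunk_groups)]

    print(split_into_groups(10, 6, 2))  # -> [[2, 2, 2], [2, 1, 1]]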
61 changes: 60 additions & 1 deletion msccl/language/mscclpp/__init__.py
@@ -132,7 +132,7 @@ def lower(self):
self.instr_dag.optimize()
self.instr_dag.lower_pt1(self.instances)
gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy)
return Program(
program = Program(
self.name,
self.collective.name,
self.collective.inplace,
@@ -142,6 +142,11 @@
self.num_threads_per_block,
self.use_double_scratch_buffer,
)
for gpu in program.gpus:
gpu.input_chunks = len(self.buffers[gpu.rank][Buffer.input]) * self.instances
if not self.collective.inplace:
gpu.output_chunks = len(self.buffers[gpu.rank][Buffer.output]) * self.instances
return program

def generate_json(self):
return ir_to_json(self.lower())
@@ -327,6 +332,60 @@ def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm):
def reduce_packet(self, other_chunkref, recvtb=-1):
return self._reduce(other_chunkref, recvtb, use_packet=True)

# """
# Group operations. These operations are used to perform collective operations across multiple chunks.
# For now, all chunks must have the same buffer type and offset.
# """
# Reads the chunks referenced by other_chunkrefs and reduces them into the chunk referenced by this chunkref
def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls):
assert (
len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls
), "Group load reduce only supports nvls channel"
nranks_per_node = self.prog.collective.num_ranks_per_node
for other_chunkref in other_chunkrefs:
assert (
self.rank // nranks_per_node == other_chunkref.rank // nranks_per_node
), "Group load reduce only supports chunks on the same node"
assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer"
assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index"

src_chunkref = other_chunkref
self.prog.apply_reduce(
src_chunkref.rank,
src_chunkref.buffer,
src_chunkref.index,
self.rank,
self.buffer,
self.index,
self.size,
)
self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type)
return self

# Copies the chunk referenced by this chunkref onto the matching chunk on each destination rank in dsts
def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
for dst in dsts:
self.prog.check_buffer_exists(dst, buffer)
assert index == -1 or self.index == index, "Group store only supports chunks with the same index"
assert chan_type == ChannelType.nvls, "Group store only supports nvls channel"

other_chunkrefs = []
nrank_per_node = self.prog.collective.num_ranks_per_node
for dst in dsts:
# The destination must be directly linked to this rank (or be this rank)
buffer, index = self._get_buffer_index(dst, buffer, index)
assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}"
assert self.buffer == buffer, "Group store only supports chunks with the same buffer"
assert (
self.rank // nrank_per_node == dst // nrank_per_node
), "Group store only supports chunks on the same node"

dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
other_chunkrefs.append(dst_chunkref)
# add the group_store op to the instruction DAG
self.prog.instr_dag.add_group_store(self.rank, self, other_chunkrefs, sendtb, chan_type)

def get_origin_index(self, index=0):
return self._get_chunk(index + self.index).origin_index

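Putting the two new chunk operations together, a minimal sketch of one rank's NVLS step (mirrors allreduce_nvls.py above; assumes it runs inside an MSCCLPPProgram context, where chunk and Buffer are in scope):

    def nvls_reduce_broadcast(rank, index, size):
        # Reduce the peers' chunks into this rank's chunk, then broadcast it back.
        peers = [r for r in range(size) if r != rank]
        c = chunk(rank, Buffer.input, index)
        peer_chunks = [chunk(r, Buffer.input, index) for r in peers]
        c = c.group_load_reduce(peer_chunks, recvtb=0)  # NVLS multicast load + reduce
        c.group_store(peers, sendtb=0)                  # NVLS multicast store to peers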
113 changes: 97 additions & 16 deletions msccl/language/mscclpp/instruction_dag.py
@@ -191,29 +191,88 @@ def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type):
self._write(rank, buffer, index, size, op, read=True)
return op

def add_group_load_reduce(self, rank, send_refs, recv_ref, tb, ch_type):
tb_step = self._get_tb_step(rank, tb)
op = Op(
Instruction.group_load_reduce,
rank,
recv_ref,
recv_ref,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
# treat recv_ref as src for group_load_reduce
op.srcs.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
for send_ref in send_refs:
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
buffer = recv_ref.buffer
index = recv_ref.index
size = recv_ref.size
self._write(rank, buffer, index, size, op, read=True)

def add_group_store(self, rank, send_ref, recv_refs, tb, ch_type):
tb_step = self._get_tb_step(rank, tb)
op = Op(
Instruction.group_store,
rank,
send_ref,
send_ref,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
# treat send_ref as dst for group_store
op.dsts.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
for recv_ref in recv_refs:
op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
self._read(rank, buffer, index, size, op)
return op

def complete_channels(self):
send_op = [Instruction.put, Instruction.signal, Instruction.put_packet]
recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy]
group_send_op = [Instruction.group_store]
group_recv_op = [Instruction.group_load_reduce]
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
chans = set()
for op in tb.ops:
src_buffer = (
Buffer.scratch
if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output
else op.src.buffer
)
dst_buffer = (
Buffer.scratch
if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output
else op.dst.buffer
)
if op.inst in send_op:
chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank)
chans.add(chan)
elif op.inst in recv_op:
chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank)
chans.add(chan)
if op.src != None:
src_buffer = (
Buffer.scratch
if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output
else op.src.buffer
)
if op.dst != None:
dst_buffer = (
Buffer.scratch
if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output
else op.dst.buffer
)
if op.channel_type == ChannelType.nvls:
if op.inst in group_send_op:
ranks = [dst[0].rank for dst in op.dsts]
chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks)
chans.add(chan)
elif op.inst in group_recv_op:
ranks = [src[0].rank for src in op.srcs]
chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks)
chans.add(chan)
else:
if op.inst in send_op:
chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank)
chans.add(chan)
elif op.inst in recv_op:
chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank)
chans.add(chan)
tb.channels = list(chans)

def remove_redundant_signal_wait(self):
@@ -326,6 +385,27 @@ def _optimize_rrcs_rs(self):
continue
queue = queue[1:]

# glre(srcs, sbuf, si, _, _, _), gstore (_, _, _, dsts, dbuf, di) -> glres(srcs, sbuf, si, dsts, dbuf, di)
def _optimize_group_ops(self):
optimizer = InstructionOptimizer()
inst_types = [
Instruction.group_load_reduce,
]
for _, rank_tbs in enumerate(self.tbs):
for _, tb in rank_tbs.items():
queue = list(tb.ops)
while len(queue) > 0:
op = queue[0]
fused = False
if op.inst in inst_types:
for next_op in op.next:
fused = optimizer.try_fuse_with_group_store(op, next_op, tb, queue)
if fused:
break
if fused:
continue
queue = queue[1:]

# merge ops that are independent of other operations, with no other operations in between
# get(src, sbuf, si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di])
# put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di])
@@ -372,6 +452,7 @@ def optimize(self):
self._fuse_instructions_using_proxy_channel()
self._fuse_same_instructions()
self._optimize_rrcs_rs()
self._optimize_group_ops()
self._compact_instructions()

def replicate(self, instances: int, replication_policy: ReplicationPolicy):
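One detail of complete_channels worth noting: for NVLS group operations the channel's peer field now records the full set of participating ranks rather than a single rank. A toy illustration with a simplified stand-in for msccl's Channel (field names assumed; not code from this commit):

    from collections import namedtuple

    Channel = namedtuple("Channel", ["src_buffer", "dst_buffer", "type", "peer"])

    p2p = Channel("input", "input", "sm", 1)             # point-to-point: one peer rank
    nvls = Channel("input", "input", "nvls", (1, 2, 3))  # group op: all peer ranks
    print(p2p.peer, nvls.peer)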
29 changes: 29 additions & 0 deletions msccl/language/mscclpp/instruction_optimizer.py
@@ -187,6 +187,35 @@ def try_fuse_instructions_using_proxy_channel(
return True
return False

def try_fuse_with_group_store(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -> bool:
"""
Attempts to fuse 'group_load_reduce' operations with 'group_store' operations.
:param op: The current operation.
:param next_op: The next operation to potentially merge with.
:param tb: The thread block containing the operations.
:param queue: The queue of operations.
:return: True if operations are merged, False otherwise.
"""
if (
next_op.inst == Instruction.group_store
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
# Append the destination chunk from next_op
op.inst = Instruction.group_load_reduce_store
op.src = next_op.src
for dst in next_op.dsts:
op.dsts.append(dst)
# Merge operations
merge_op(op, next_op)
tb.ops.remove(next_op)
queue.remove(next_op)
return True
return False

def try_remove_op(self, pending_remove_op: Op, condition: bool) -> bool:
if condition:
remove_op(pending_remove_op)
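The new pass collapses a group_load_reduce followed by a group_store into one group_load_reduce_store. A self-contained sketch of that peephole with a simplified stand-in for msccl's Op (illustrative only; the real logic also checks counts, buffers, channel types, and dependencies):

    from dataclasses import dataclass, field

    @dataclass
    class MiniOp:  # simplified stand-in for msccl's Op
        inst: str
        dsts: list = field(default_factory=list)

    def fuse_glre_gstore(ops):
        # Fuse each adjacent group_load_reduce + group_store pair into one op.
        out, i = [], 0
        while i < len(ops):
            op = ops[i]
            if (
                op.inst == "group_load_reduce"
                and i + 1 < len(ops)
                and ops[i + 1].inst == "group_store"
            ):
                op = MiniOp("group_load_reduce_store", op.dsts + ops[i + 1].dsts)
                i += 2  # consume the fused group_store
            else:
                i += 1
            out.append(op)
        return out

    ops = [MiniOp("group_load_reduce", [0]), MiniOp("group_store", [1, 2])]
    print([o.inst for o in fuse_glre_gstore(ops)])  # -> ['group_load_reduce_store']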
