diff --git a/examples/mscclang/mscclpp/allreduce_nvls.py b/examples/mscclang/mscclpp/allreduce_nvls.py new file mode 100644 index 0000000..35d5ea2 --- /dev/null +++ b/examples/mscclang/mscclpp/allreduce_nvls.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + + +def allreduce_allpairs(gpus, instances): + size = gpus + chunksperloop = gpus + topology = fully_connected(size) + collective = AllReduce(size, chunksperloop, True) + with MSCCLPPProgram( + "allreduce_nvls", + topology, + collective, + instances, + ): + # Each rank sends the nth chunk to the nth rank into scratch space + for rank in range(size): + index = rank + c = chunk(rank, Buffer.input, index) + reduce_chunks = [] + # make sure the data is ready + for nghr in range(size): + if rank != nghr: + c_peer = chunk(nghr, Buffer.input, index) + reduce_chunks.append(c_peer) + c.signal(nghr, Buffer.input, index, sendtb=0) + for nghr in range(size): + if rank != nghr: + c.wait(nghr, Buffer.input, index, recvtb=0) + c = c.group_load_reduce(reduce_chunks, recvtb=0) + ngbrs = [nghr for nghr in range(size) if nghr != rank] + c.group_store(ngbrs, sendtb=0) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances) diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index 6dc20cc..c8434da 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -4,13 +4,24 @@ class Collective: - def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups=1): + + def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs): self.num_ranks = num_ranks self.chunk_factor = chunk_factor self.inplace = inplace self.name = "custom" - # Devide the buffer into num_chunk_groups groups - self.num_chunk_groups = num_chunk_groups + # Devide the buffer into num_chunk_groups group + if num_ranks_per_node == -1: + self.num_ranks_per_node = num_ranks + else: + self.num_ranks_per_node = num_ranks_per_node + + # kwargs + # Number of chunk groups: which means we will group n chunks into m groups. + # We will gurantee that the group size is the same. + # But in the same group, the chunk size may be different due to group size + # can not be divided by the number of chunks. + self.num_chunk_groups = kwargs.get("num_chunk_groups", 1) def init_buffers(self): pass @@ -120,10 +131,11 @@ def get_buffer_index(self, rank, buffer, index): class AllReduce(Collective): - def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups=None): - if num_chunk_groups == None: - num_chunk_groups = num_ranks - Collective.__init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups) + def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs): + num_chunk_groups = kwargs.get('num_chunk_groups', num_ranks) + Collective.__init__( + self, num_ranks, chunk_factor, inplace, num_ranks_per_node, num_chunk_groups=num_chunk_groups + ) self.name = "allreduce" def init_buffers(self): diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py index 9f00ad4..13ea870 100644 --- a/msccl/language/mscclpp/__init__.py +++ b/msccl/language/mscclpp/__init__.py @@ -132,7 +132,7 @@ def lower(self): self.instr_dag.optimize() self.instr_dag.lower_pt1(self.instances) gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy) - return Program( + program = Program( self.name, self.collective.name, self.collective.inplace, @@ -142,6 +142,11 @@ def lower(self): self.num_threads_per_block, self.use_double_scratch_buffer, ) + for gpu in program.gpus: + gpu.input_chunks = len(self.buffers[gpu.rank][Buffer.input]) * self.instances + if not self.collective.inplace: + gpu.output_chunks = len(self.buffers[gpu.rank][Buffer.output]) * self.instances + return program def generate_json(self): return ir_to_json(self.lower()) @@ -327,6 +332,60 @@ def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm): def reduce_packet(self, other_chunkref, recvtb=-1): return self._reduce(other_chunkref, recvtb, use_packet=True) + # """ + # Group operations. These operations are used to perform collective operations across multiple chunks. + # For now, all chunks must has the same buffer type and offset. + # """ + # Reads the chunk(s) referenced by other_chunkref and reduce into the chunk referenced by this chunkref + def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls): + assert ( + len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls + ), "Group load reduce only supports nvls channel" + nranks_per_node = self.prog.collective.num_ranks_per_node + for other_chunkref in other_chunkrefs: + assert ( + self.rank // nranks_per_node == other_chunkref.rank // nranks_per_node + ), "Group load reduce only supports chunks on the same node" + assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer" + assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index" + + src_chunkref = other_chunkref + self.prog.apply_reduce( + src_chunkref.rank, + src_chunkref.buffer, + src_chunkref.index, + self.rank, + self.buffer, + self.index, + self.size, + ) + self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type) + return self + + # Copies the chunk(s) referenced by this chunkref onto other_chunkrefs + def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls): + for dst in dsts: + self.prog.check_buffer_exists(dst, buffer) + assert index == -1 or self.index == index, "Group store only supports chunks with the same index" + assert chan_type == ChannelType.nvls, "Group store only supports nvls channel" + + other_chunkrefs = [] + nrank_per_node = self.prog.collective.num_ranks_per_node + for dst in dsts: + # Direct linked + buffer, index = self._get_buffer_index(dst, buffer, index) + assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" + assert self.buffer == buffer, "Group store only supports chunks with the same buffer" + assert ( + self.rank // nrank_per_node == dst // nrank_per_node + ), "Group store only supports chunks on the same node" + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + other_chunkrefs.append(dst_chunkref) + # add new op here + self.prog.instr_dag.add_group_store(self.rank, self, other_chunkrefs, sendtb, chan_type) + def get_origin_index(self, index=0): return self._get_chunk(index + self.index).origin_index diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index 6f20abf..0f692bf 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -191,29 +191,88 @@ def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type): self._write(rank, buffer, index, size, op, read=True) return op + def add_group_load_reduce(self, rank, send_refs, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.group_load_reduce, + rank, + recv_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + # treat recv_ref as src for group_load_reduce + op.srcs.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + for send_ref in send_refs: + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op, read=True) + + def add_group_store(self, rank, send_ref, recv_refs, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.group_store, + rank, + send_ref, + send_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + # treat send_ref as dst for group_store + op.dsts.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + for recv_ref in recv_refs: + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + return op + def complete_channels(self): send_op = [Instruction.put, Instruction.signal, Instruction.put_packet] recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] + group_send_op = [Instruction.group_store] + group_recv_op = [Instruction.group_load_reduce] for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): chans = set() for op in tb.ops: - src_buffer = ( - Buffer.scratch - if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output - else op.src.buffer - ) - dst_buffer = ( - Buffer.scratch - if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output - else op.dst.buffer - ) - if op.inst in send_op: - chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) - chans.add(chan) - elif op.inst in recv_op: - chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) - chans.add(chan) + if op.src != None: + src_buffer = ( + Buffer.scratch + if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output + else op.src.buffer + ) + if op.dst != None: + dst_buffer = ( + Buffer.scratch + if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output + else op.dst.buffer + ) + if op.channel_type == ChannelType.nvls: + if op.inst in group_send_op: + ranks = [dst[0].rank for dst in op.dsts] + chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks) + chans.add(chan) + elif op.inst in group_recv_op: + ranks = [src[0].rank for src in op.srcs] + chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks) + chans.add(chan) + else: + if op.inst in send_op: + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) + chans.add(chan) + elif op.inst in recv_op: + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) + chans.add(chan) tb.channels = list(chans) def remove_redundant_signal_wait(self): @@ -326,6 +385,27 @@ def _optimize_rrcs_rs(self): continue queue = queue[1:] + # glre(srcs, sbuf, si, _, _, _), gstore (_, _, _, dsts, dbuf, di) -> glres(srcs, sbuf, si, dsts, dbuf, di) + def _optimize_group_ops(self): + optimizer = InstructionOptimizer() + inst_types = [ + Instruction.group_load_reduce, + ] + for _, rank_tbs in enumerate(self.tbs): + for _, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + fused = False + if op.inst in inst_types: + for next_op in op.next: + fused = optimizer.try_fuse_with_group_store(op, next_op, tb, queue) + if fused: + break + if fused: + continue + queue = queue[1:] + # merge ops which are independent of other operations and no other operations in between # get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di]) # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di]) @@ -372,6 +452,7 @@ def optimize(self): self._fuse_instructions_using_proxy_channel() self._fuse_same_instructions() self._optimize_rrcs_rs() + self._optimize_group_ops() self._compact_instructions() def replicate(self, instances: int, replication_policy: ReplicationPolicy): diff --git a/msccl/language/mscclpp/instruction_optimizer.py b/msccl/language/mscclpp/instruction_optimizer.py index f36303b..14ddf1a 100644 --- a/msccl/language/mscclpp/instruction_optimizer.py +++ b/msccl/language/mscclpp/instruction_optimizer.py @@ -187,6 +187,35 @@ def try_fuse_instructions_using_proxy_channel( return True return False + def try_fuse_with_group_store(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -> bool: + """ + Attempts to fuse 'gruop_load_reduce' operations with 'group_store' operations. + :param op: The current operation. + :param next_op: The next operation to potentially merge with. + :param tb: The thread block containing the operations. + :param queue: The queue of operations. + :return: True if operations are merged, False otherwise. + """ + if ( + next_op.inst == Instruction.group_store + and same_count(op, next_op) + and buf_dst_src_match(op, next_op) + and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) + and all_prevs_visited_after_merge(op, next_op) + ): + # Append the destination chunk from next_op + op.inst = Instruction.group_load_reduce_store + op.src = next_op.src + for dst in next_op.dsts: + op.dsts.append(dst) + # Merge operations + merge_op(op, next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + return True + return False + def try_remove_op(self, pending_remove_op: Op, condition: bool) -> bool: if condition: remove_op(pending_remove_op) diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index f5ba0fd..3eaa651 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -20,6 +20,8 @@ Instruction.reduce_packet, Instruction.reduce_send, Instruction.reduce_send_packet, + Instruction.group_load_reduce_store, + Instruction.group_store, } _local_dst_insts_mscclpp: set = { Instruction.get, @@ -33,6 +35,8 @@ Instruction.reduce_send, Instruction.reduce_packet, Instruction.reduce_send_packet, + Instruction.group_load_reduce_store, + Instruction.group_load_reduce, } _insts_no_need_sync_barrier: set = { @@ -148,20 +152,32 @@ def dump_to_json(program: Program): def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_type): channel_ids = [] - for c in chunk_list: - key = (src_buffer, dst_buffer, chan_type) + key = (src_buffer, dst_buffer, chan_type) + if chan_type == ChannelType.nvls: + ranks = [] + for c in chunk_list: + ranks.append(c.rank) channel_ids.extend( - [ - {"id": id, "off": c.index} - for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) - if ele == c.rank - ] + [{"id": id} for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) if set(ele) == set(ranks)] ) + else: + for c in chunk_list: + channel_ids.extend( + [ + {"id": id, "off": c.index} + for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) + if ele == c.rank + ] + ) return channel_ids def remove_empty_fields(d): return {k: v for k, v in d.items() if v not in [None, "", [], {}]} + max_scratch = max(gpu.scratch_chunks for gpu in program.gpus) + max_input = max(gpu.input_chunks for gpu in program.gpus) + max_output = max(gpu.output_chunks for gpu in program.gpus) + for id, gpu in enumerate(program.gpus): gpu_instance = { "id": id, @@ -179,9 +195,27 @@ def remove_empty_fields(d): "type": type.value, "connectedTo": [eles[1] for eles in channels], } + if type == ChannelType.nvls: + obj["connectedTo"] = [sorted(list(eles)) for eles in obj["connectedTo"]] gpu_instance["channels"].append(obj) gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"])) + + # render for GPU NVLS channels + for i, chan in enumerate(gpu_instance["channels"]): + if chan["type"] == "nvls": + buff = chan["srcbuff"] + buffer_size = ( + max_input + if buff == Buffer.input.value + else max_output if buff == Buffer.output.value else max_scratch + ) + gpu_instance["channels"][i] = { + "buff": chan["srcbuff"], + "type": chan["type"], + "rankGroups": [{"size": buffer_size, "ranks": ranks} for ranks in chan["connectedTo"]], + } + for tb in gpu.threadblocks: if tb.id < 0: continue @@ -282,6 +316,15 @@ def remove_empty_fields(d): ): src = op.src dst = op.dst + elif op.inst == Instruction.group_load_reduce_store: + src = op.src + dst = op.dst + src_channel_ids = get_channel_ids( + op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) + dst_channel_ids = get_channel_ids( + op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) if op.inst != Instruction.nop: instr = { "name": op.inst.value, diff --git a/msccl/language/types.py b/msccl/language/types.py index 9a8e676..ef2addd 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Union +from typing import Union, List from msccl.language.buffer import Buffer @@ -124,6 +124,9 @@ class MscclppInstruction(Enum): wait = "wait" signal = "signal" flush = "flush" + group_store = "gstore" + group_load_reduce = "glre" + group_load_reduce_store = "glres" def __str__(self): return self.value @@ -144,6 +147,7 @@ class ChannelType(Enum): proxy = "proxy" sm = "sm" none = "none" + nvls = "nvls" def __str__(self): return self.value @@ -154,7 +158,12 @@ class Channel: srcBuffer: Buffer dstBuffer: Buffer type: ChannelType - connected_to: int + connected_to: Union[int, List[int]] + + def __hash__(self): + # Ensure connected_to is converted to a tuple if it's a list + connected_to_hashable = tuple(self.connected_to) if isinstance(self.connected_to, list) else self.connected_to + return hash((self.srcBuffer, self.dstBuffer, self.type, connected_to_hashable)) @dataclass