From 58b7d9f16eb22af7e90cdbde02ada135abc1f853 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 20 Mar 2024 08:17:22 +0000 Subject: [PATCH 01/76] integration branch --- msccl/language/__init__.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 1473849..16092f9 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -26,7 +26,7 @@ def __init__(self, name, topo, collective, instances, protocol='Simple', \ instr_fusion=True, check_xml=True, dependence_nop=False): self.name = name self.topo = topo - self.collective = collective + self.collective = collective self.num_ranks = topo.num_nodes() self.instances = instances self.protocol = protocol @@ -54,7 +54,7 @@ def __enter__(self): if _current_program != None: raise RuntimeError("There is already a MSCCL Program in context") _current_program = self - + def __exit__(self, exc_type, exc_value, exc_traceback): global _current_program if _current_program != self: @@ -120,14 +120,14 @@ def lower(self): gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) if self.check_xml: # Check generated MSCCL-IR for correctness - no circular dependencies, sends and receives are ordered - # For very large programs, turn off check_xml when shipping + # For very large programs, turn off check_xml when shipping check_dependency_cycles(self.instr_dag.tbs) check_threadblock_ordering(self.instr_dag) - return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) + return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) def generate_xml(self): return ir_to_xml(self.lower(), dependence_nop=self.dependence_nop) - + def print_chunk_dag(self): visualize_chunk_dag(self.chunk_dag.chunk_paths) @@ -189,7 +189,13 @@ def group(self, other): end = max(first._end(), second._end()) return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) - + + def put(self, dst, buffer=None, index=-1, sendtb=-1): + self.prog.check_buffer_exists(dst, buffer) + + def get(self, src, buffer=None, index=-1, recvtb=-1): + self.prog.check_buffer_exists(src, buffer) + # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): self.prog.check_buffer_exists(dst, buffer) @@ -214,7 +220,7 @@ def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): # chunks = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) # overwritten_chunks = self.prog.get_chunks(dst, buffer, index, self.size) - + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) @@ -266,7 +272,7 @@ def get_dst_rank(self, index=0): return self._get_chunk(index + self.index).dst_rank def print_chunk_info(self, index=0): - print(self._get_chunk(index + self.index)) + print(self._get_chunk(index + self.index)) # @dataclass @@ -278,7 +284,7 @@ def print_chunk_info(self, index=0): # recvtb: int = -1# For lowering to RankInstructions # ch: int = -1 # For lowering to RankInstructions # steps_from_start:int = -1 -# steps_to_end: int = -1 +# steps_to_end: int = -1 # prev: list = field(default_factory=list) # Previous ChunkOps # next: list = field(default_factory=list) # Next ChunkOps # visited = False @@ -291,7 +297,7 @@ def print_chunk_info(self, index=0): # return self.steps_from_start < other.steps_from_start # def __hash__(self): -# return hash((self.inst, self.dst.rank, self.dst.index, self.dst.buffer)) # TODO +# return hash((self.inst, self.dst.rank, self.dst.index, self.dst.buffer)) # TODO # def same_slot(ref1, ref2): # return ref1.rank == ref2.rank and ref1.buffer == ref2.buffer and ref1.index == ref2.index @@ -348,7 +354,7 @@ def print_chunk_info(self, index=0): # # steps_from_start = max(steps_from_start, prev_op.steps_from_start) # # prev_ops.append(prev_op) # op = ChunkOp(ChunkInstruction.send, src, dst, sendtb, recvtb, ch, steps_from_start+1) - + # for prev_op in prev_ops: # prev_op.next.append(op) # op.prev = prev_ops @@ -364,7 +370,7 @@ def print_chunk_info(self, index=0): # steps_from_start = max(prev_op_src.steps_from_start, prev_op_dst.steps_from_start, steps_from_start) # prev_ops.append(prev_op_src) # prev_ops.append(prev_op_dst) - + # op = ChunkOp(ChunkInstruction.reduce, src, dst, sendtb, recvtb, ch, steps_from_start+1) # for prev_op in prev_ops: @@ -387,14 +393,14 @@ def print_chunk_info(self, index=0): # for chunk, op in self.chunk_paths.items(): # if op.inst == ChunkInstruction.start: # dfs(op) - + # # Assigns each send and a reduce a channel for communication based of policies # def channel_assignment(self, channel_policy='zero'): # frontier = [] # visited = set() # for chunk, op in self.chunk_paths.items(): -# if len(op.prev) == 0: +# if len(op.prev) == 0: # heapq.heappush(frontier, op) # # If an op isn't annotated with a channel set it to 0 @@ -412,7 +418,7 @@ def print_chunk_info(self, index=0): # visited = set() # for chunk, op in self.chunk_paths.items(): -# if len(op.prev) == 0: +# if len(op.prev) == 0: # heapq.heappush(frontier, ((op.steps_from_start, op.steps_to_end), op)) # while len(frontier) > 0: From e0394848eecb0356fec0a2b43d04615c2db72033 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 20 Mar 2024 14:15:51 +0000 Subject: [PATCH 02/76] WIP --- .vscode/launch.json | 17 +++++++ examples/mscclang/allreduce_a100_allpairs.py | 10 ++-- .../allreduce_a100_allpairs_mscclpp.py | 51 +++++++++++++++++++ .../allreduce_a100_allpairs_mscclpp_v2.py | 45 ++++++++++++++++ msccl/language/__init__.py | 37 +++++++++++++- msccl/language/ir.py | 15 ++++-- msccl/language/rank_dag.py | 45 +++++++++------- 7 files changed, 191 insertions(+), 29 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 examples/mscclang/allreduce_a100_allpairs_mscclpp.py create mode 100644 examples/mscclang/allreduce_a100_allpairs_mscclpp_v2.py diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..1bdfdd3 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,17 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File with Arguments", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "args": "4 1 --protocol Simple", + "justMyCode": false + } + ] +} diff --git a/examples/mscclang/allreduce_a100_allpairs.py b/examples/mscclang/allreduce_a100_allpairs.py index e6ec80d..ea080bc 100755 --- a/examples/mscclang/allreduce_a100_allpairs.py +++ b/examples/mscclang/allreduce_a100_allpairs.py @@ -11,9 +11,9 @@ def allreduce_allpairs(gpus, instances, protocol): chunksperloop = gpus * gpus topology = fully_connected(size) collective = AllReduce(size, chunksperloop, True) - with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, + with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): - + # Each rank sends the nth chunk to the nth rank into scratch space for r1 in range(size): for r2 in range(size): @@ -28,7 +28,7 @@ def allreduce_allpairs(gpus, instances, protocol): for index in range(0, size * (size-1)): c = chunk(r, Buffer.input, r*size + (index % size)) c.reduce(chunk(r, 'scratch', index), sendtb=(index % size)) - + # Each rank sends the fully reduced nth chunk to all other gpus for r1 in range(size): for r2 in range(size): @@ -36,7 +36,7 @@ def allreduce_allpairs(gpus, instances, protocol): index = r1 * size c = chunk(r1, Buffer.input, index, size) c.copy(r2, Buffer.input, index, sendtb=r2, recvtb=r1) - + XML() Check() @@ -47,4 +47,4 @@ def allreduce_allpairs(gpus, instances, protocol): args = parser.parse_args() -allreduce_allpairs(args.num_gpus, args.instances, args.protocol) \ No newline at end of file +allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/examples/mscclang/allreduce_a100_allpairs_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_mscclpp.py new file mode 100644 index 0000000..53a81a0 --- /dev/null +++ b/examples/mscclang/allreduce_a100_allpairs_mscclpp.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + +def allreduce_allpairs(gpus, instances, protocol): + size = gpus + chunksperloop = gpus * gpus + topology = fully_connected(size) + collective = AllReduce(size, chunksperloop, True) + with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, + interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): + + # Each rank sends the nth chunk to the nth rank into scratch space + for r1 in range(size): + for r2 in range(size): + if r1 != r2: + index = r2 * size + c = chunk(r1, Buffer.input, index, size=size) + c.put(r2, 'scratch', index=r1, sendtb=r2) + + # Each rank performs a local reduction on the nth chunk + # Utilize 8 threadblocks for this reduction for better parallelism + for r in range(size): + for index in range(0, size * (size-1)): + c = chunk(r, Buffer.input, r*size + (index % size)) + c.reduce(chunk(r, 'scratch', index), sendtb=(index % size)) + + # Each rank sends the fully reduced nth chunk to all other gpus + for r1 in range(size): + for r2 in range(size): + index = r1 * size + c = chunk(r1, Buffer.input, index + r2) + for r3 in range(size): + if r3 != r1: + c.put(r3, Buffer.input, index, sendtb=r2) + + XML() + Check() + +parser = argparse.ArgumentParser() +parser.add_argument('num_gpus', type=int, help ='number of gpus') +parser.add_argument('instances', type=int, help='number of instances') +parser.add_argument('--protocol', type=str, default='LL', choices=['Simple', 'LL'], help='Protocol') + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/examples/mscclang/allreduce_a100_allpairs_mscclpp_v2.py b/examples/mscclang/allreduce_a100_allpairs_mscclpp_v2.py new file mode 100644 index 0000000..4589e9b --- /dev/null +++ b/examples/mscclang/allreduce_a100_allpairs_mscclpp_v2.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + +def allreduce_allpairs(gpus, instances, protocol): + size = gpus + chunksperloop = gpus * gpus + topology = fully_connected(size) + collective = AllReduce(size, chunksperloop, True) + with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, + interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): + + # Each rank sends the nth chunk to the nth rank into scratch space + for rank in range(size): + for tb in range(size): + index = rank * size + c = chunk(rank, Buffer.input, index + tb) + for nghr in range(size): + if rank != nghr: + c.reduce(chunk(nghr, 'input', index + tb), recvtb==tb) + + # Each rank sends the fully reduced nth chunk to all other gpus + for rank in range(size): + for tb in range(size): + index = rank * size + c = chunk(rank, Buffer.input, index + tb) + for nghr in range(size): + if rank != nghr: + c.put(nghr, Buffer.input, index, sendtb=tb) + + XML() + Check() + +parser = argparse.ArgumentParser() +parser.add_argument('num_gpus', type=int, help ='number of gpus') +parser.add_argument('instances', type=int, help='number of instances') +parser.add_argument('--protocol', type=str, default='LL', choices=['Simple', 'LL'], help='Protocol') + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 16092f9..f349119 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -138,6 +138,9 @@ def print_instr_dags(self, rank): else: visualize_instr_dag(self.instr_dags[rank].operations) +class MSCCLPPProgram: + pass + def Print(): _curr().print_chunk_dag() @@ -190,8 +193,40 @@ def group(self, other): end = max(first._end(), second._end()) return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) - def put(self, dst, buffer=None, index=-1, sendtb=-1): + def put(self, dst, buffer=None, index=-1, sendtb=-1, channel_type="SM"): self.prog.check_buffer_exists(dst, buffer) + # If index is not specified assume it is going to the same place in the next gpu + if index == -1 and buffer == None: + index = self.index + buffer = self.buffer + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + index = self.prog.buffers[dst][buffer].instance_size() + + # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) + buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) + + # Direct put + assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}' + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) + if dst_chunkref == self: + return + + # chunks = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) + # overwritten_chunks = self.prog.get_chunks(dst, buffer, index, self.size) + + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + + # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) + sender = self.rank + receiver = dst + if sender != receiver: + sop = self.prog.instr_dag.add_send(sender, self, dst_chunkref, sendtb) + else: + self.prog.instr_dag.add_copy(sender, self, dst_chunkref, sendtb) + + return dst_chunkref def get(self, src, buffer=None, index=-1, recvtb=-1): self.prog.check_buffer_exists(src, buffer) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 1cbfade..c1299d2 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -76,8 +76,13 @@ class Instruction(Enum): recv_reduce_copy_send = 'rrcs' copy = 'cpy' reduce = 're' - delete = 'd' + delete = 'd' start = 'st' + put = 'put' + get = 'get' + wait = 'wait' + signal = 'signal' + flush = 'flush' def __str__(self): return self.value @@ -93,7 +98,7 @@ def __str__(self): def __lt__(self, other): return self.value < other.value - + def __gt__(self, other): return self.value < other.value @@ -172,7 +177,7 @@ def send_peer(self): if self.is_send(): return self.dst.rank return -1 - + def recv_peer(self): if self.is_recv(): return self.src.rank @@ -244,7 +249,7 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= op.depends = list( filter(lambda dep: op_tb_id[dep] != tb_id[tb], op.depends)) # Filter out redundant dependencies - # e.g. if op1 and op2 depend on op, and op1 happends before op2 + # e.g. if op1 and op2 depend on op, and op1 happends before op2 # then op2 does not need to explicitly depend on op for gpu in program.gpus: for tb in gpu.threadblocks: @@ -276,7 +281,7 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= for dep in op.depends: if first_dep is None: first_dep = dep - else: + else: pre_ops.append(Op(Instruction.nop, -1, None, None, [dep])) op.depends = [] if first_re is None: diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index ae00960..36e1898 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -23,7 +23,7 @@ def same_tb(op1, op2): def same_count(op1, op2): return op1.cnt() == op2.cnt() - + def same_buf_dst(op1, op2): return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index @@ -36,9 +36,9 @@ def __init__(self, num_ranks, buffers): self.last_writer = {} # slot -> last writing op self.last_readers = defaultdict(list) # slot -> list of last reading ops # State for the MSCCL-IR - self.tbs = [] + self.tbs = [] for _ in range(num_ranks): - self.tbs.append({}) + self.tbs.append({}) self.tb_mapping = {} self.num_channels = [1] * num_ranks @@ -62,7 +62,7 @@ def _write(self, rank, buffer, index, size, op, read=False): prev_ops.update(readers) elif slot in self.last_writer: prev_ops.add(self.last_writer[slot]) - + # Set the last_writer to this op, and clear all readers self.last_writer[slot] = op self.last_readers[slot] = [] @@ -82,7 +82,7 @@ def _read(self, rank, buffer, index, size, op): writer = self.last_writer[slot] prev_ops.add(writer) self.last_readers[slot].append(op) - + # Update the next pointer of the previous ops for prev_op in prev_ops: prev_op.next.add(op) @@ -133,6 +133,15 @@ def add_send(self, rank, send_ref, recv_ref, tb, ch): self._read(rank, buffer, index, size, op) return op + # InstructionDAG - adds a put node + def add_put(self, rank, send_ref, recv_ref, tb, ch): + op = Op(Instruction.put, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + return op + # InstructionDAG - adds a recv node def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op): op = Op(Instruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) @@ -172,7 +181,7 @@ def convert_set_list(self): ops = ops[1:] + op.next else: ops = ops[1:] - + def optimize(self): self._optimize_rrcs_rrs() self._optimize_rcs() @@ -198,13 +207,13 @@ def dfs(op, cs): for chunk, op in self.operations.items(): if op.inst == Instruction.start: dfs(op,-2) # Start instructions should start at -1 - - + + # Given the set of operations that operate over a particular slot (rank, buffer, idx) fixed # Try and replace operations with pipelined ops like receive copy send (rcs) # or receive reduce send (rrs) and receive reduce copy send (rrcs) # Rules: - # recv-copy-send + # recv-copy-send # recv(src, sbuf, si, _, _, _ ) send(_, _, _, dst, dbuf, di) -> recv_copy_send(src, sbuf, si, dst, dbuf, di) def _optimize_rcs(self): for slot, ops in self.operations.items(): @@ -222,7 +231,7 @@ def _optimize_rcs(self): break frontier = frontier[1:] + op.next # recv-reduce-send - A rrc followed by a send that gets overwritten - # rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) recv(_, _, _, dst, dbuf, di) + # rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) recv(_, _, _, dst, dbuf, di) # recv-reduce-copy-send - A rrc followed by a send that does not get overwritten # rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) def _optimize_rrcs_rrs(self): @@ -241,7 +250,7 @@ def _optimize_rrcs_rrs(self): next_op.recv_match.send_match = op op.recv_match = next_op.recv_match remove_op(next_op) - + if op.inst == Instruction.recv_reduce_copy and next_op.inst == Instruction.send and same_tb(op, next_op) and same_count(op, next_op) and same_buf_dst(op, next_op): op.inst = Instruction.recv_reduce_copy_send op.dst = next_op.dst @@ -253,7 +262,7 @@ def _optimize_rrcs_rrs(self): def lower_pt1(self, instances): self.infer_dependencies() self.lower_buffers(instances) - + def lower_pt2(self, instances, interleaved): self.replicate(instances, interleaved) return self.lower_tbs() @@ -311,14 +320,14 @@ def lower_tbs(self): # interleaved sets the replication policy # if True chunks are split as: ChunkA ChunkB -> ChunkA0 ChunkA1 .. ChunkB0 ChunkB1 ... # if false chunks are divided as ChunkA0 ChunkB0 ChunkA1 ChunkB1 ... - # For collectives were chunks are designated for a particular GPU (e.g. AllToAll) + # For collectives were chunks are designated for a particular GPU (e.g. AllToAll) # only interleaved replication will be correct # Interleaved policy only supports single count sends/receives from the input/output buffer # (multicount ops are fine between scratch) def replicate(self, instances, interleaved): if instances == 1: self.instanced_tbs = self.tbs - return + return self.instanced_tbs = [] for _ in range(self.num_ranks): @@ -357,12 +366,12 @@ def get_instance_ref(ref): for s, op in enumerate(tb.ops): isrc = get_instance_ref(op.src) idst = get_instance_ref(op.dst) - idepends = [] + idepends = [] # Note: We don't need the fill out the rest of the metadata since replication is the last optimization - iop = Op(op.inst, op.rank, isrc, idst, idepends, op.step, itbid) + iop = Op(op.inst, op.rank, isrc, idst, idepends, op.step, itbid) itb.ops[s] = iop self.instanced_tbs[op.rank][itbid] = itb - + # Redo dependency analysis for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): @@ -375,5 +384,5 @@ def get_instance_ref(ref): dep_tbid = dep.tb dep_itbid = dep_tbid * instances + i dep_step = dep.step - iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] + iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] From 3527e3273111825e3de835561ebcfa08c4ba7522 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 22 Mar 2024 09:48:06 +0000 Subject: [PATCH 03/76] fix --- examples/mscclang/put_mscclpp.py | 37 +++++++++++++++++++++++ msccl/language/__init__.py | 50 ++++++++++++++++++++++---------- msccl/language/ir.py | 13 +++++++++ msccl/language/rank_dag.py | 43 +++++++++++++++++++++++---- msccl/language/tb_assignment.py | 33 ++++++++++++++------- 5 files changed, 145 insertions(+), 31 deletions(-) create mode 100644 examples/mscclang/put_mscclpp.py diff --git a/examples/mscclang/put_mscclpp.py b/examples/mscclang/put_mscclpp.py new file mode 100644 index 0000000..b4bcc82 --- /dev/null +++ b/examples/mscclang/put_mscclpp.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + +def allreduce_allpairs(gpus, instances, protocol): + size = gpus + chunksperloop = gpus + topology = fully_connected(size) + collective = AllReduce(size, chunksperloop, True) + with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, + interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): + + c = chunk(0, Buffer.input, 0, size=1) + c.put(1, Buffer.input, index = 0, sendtb=0) + c.put(1, Buffer.input, index = 1, sendtb=0) + c.signal(1, sendtb=0) + + dc0 = chunk(1, Buffer.input, 1, size=1) + dc1 = chunk(1, Buffer.input, 0, size=1) + dc0.wait(0, recvtb=1) + dc1.wait(0, recvtb=1) + + Json() + #Check() + +parser = argparse.ArgumentParser() +parser.add_argument('num_gpus', type=int, help ='number of gpus') +parser.add_argument('instances', type=int, help='number of instances') +parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple', 'LL128', 'LL'], help='Protocol') + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index f349119..8a543e7 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -125,9 +125,21 @@ def lower(self): check_threadblock_ordering(self.instr_dag) return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) + # Lower program to MSCCLPP + def lower_mscclpp(self): + convert_to_exectuion_plan(self.instr_dag) + if self.instr_fusion: + self.instr_dag.optimize_mscclpp() + self.instr_dag.lower_pt1(self.instances) + gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) + return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) + def generate_xml(self): return ir_to_xml(self.lower(), dependence_nop=self.dependence_nop) + def generate_json(self): + return ir_to_xml(self.lower_mscclpp(), dependence_nop=self.dependence_nop) + def print_chunk_dag(self): visualize_chunk_dag(self.chunk_dag.chunk_paths) @@ -155,6 +167,9 @@ def create_scratch(rank, name): def XML(): print(_curr().generate_xml()) +def Json(): + print(_curr().generate_json()) + def Check(): return _curr().check() @@ -193,8 +208,12 @@ def group(self, other): end = max(first._end(), second._end()) return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) - def put(self, dst, buffer=None, index=-1, sendtb=-1, channel_type="SM"): + def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): self.prog.check_buffer_exists(dst, buffer) + sender = self.rank + receiver = dst + assert sender != receiver, 'Cannot put to the same rank' + # If index is not specified assume it is going to the same place in the next gpu if index == -1 and buffer == None: index = self.index @@ -209,27 +228,28 @@ def put(self, dst, buffer=None, index=-1, sendtb=-1, channel_type="SM"): assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}' dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) - if dst_chunkref == self: - return + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) - # chunks = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) - # overwritten_chunks = self.prog.get_chunks(dst, buffer, index, self.size) + return dst_chunkref - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) + def get(self, src, buffer=None, index=-1, recvtb=-1): + self.prog.check_buffer_exists(src, buffer) + + def signal(self, dst, sendtb=-1, chan_type=ChannelType.sm): sender = self.rank receiver = dst - if sender != receiver: - sop = self.prog.instr_dag.add_send(sender, self, dst_chunkref, sendtb) - else: - self.prog.instr_dag.add_copy(sender, self, dst_chunkref, sendtb) + assert sender != receiver, 'Cannot signal to the same rank' - return dst_chunkref + self.prog.instr_dag.add_signal(sender, self, dst, sendtb, chan_type) - def get(self, src, buffer=None, index=-1, recvtb=-1): - self.prog.check_buffer_exists(src, buffer) + def wait(self, src, recvtb=-1, chan_type=ChannelType.sm): + sender = src + receiver = self.rank + assert sender != receiver, 'Cannot wait on the same rank' + + self.prog.instr_dag.add_wait(receiver, self, src, recvtb, chan_type) # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): diff --git a/msccl/language/ir.py b/msccl/language/ir.py index c1299d2..2c905fc 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -103,6 +103,13 @@ def __gt__(self, other): return self.value < other.value +class ChannelType(Enum): + proxy = 'proxy' + sm = 'sm' + + def __str__(self): + return self.value + @dataclass class ChunkRef: @@ -132,6 +139,9 @@ class Op: recv_match = None send_match = None channel: int = -1 + channel_type: ChannelType = ChannelType.sm + dst_ranks: list = field(default_factory=list) + src_ranks: list = field(default_factory=list) def cnt(self): if self.src: @@ -398,3 +408,6 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= if pretty_print: ET.indent(algo_elem, space=' ') return ET.tostring(algo_elem, encoding='unicode') + +def ir_to_json(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): + pass diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 36e1898..75913e5 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -41,7 +41,7 @@ def __init__(self, num_ranks, buffers): self.tbs.append({}) self.tb_mapping = {} self.num_channels = [1] * num_ranks - + self.tb_steps = [{} for _ in range(num_ranks)] # InstructionDAG helper - identifies the dependencies for a write-type operation (recv, copy, rrc, reduce) def _write(self, rank, buffer, index, size, op, read=False): @@ -134,14 +134,35 @@ def add_send(self, rank, send_ref, recv_ref, tb, ch): return op # InstructionDAG - adds a put node - def add_put(self, rank, send_ref, recv_ref, tb, ch): - op = Op(Instruction.put, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + def add_put(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.put, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) buffer = send_ref.buffer index = send_ref.index size = send_ref.size self._read(rank, buffer, index, size, op) return op + # InstructionDAG - adds a signal node. + def add_signal(self, rank, send_ref, dst, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.signal, rank, send_ref, None, next=set(), prev=set(), tb=tb, channel_type=ch_type, dst_ranks=[dst], step=tb_step) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + # treat signal as a write since it can not be executed parallelly with read operations + self._write(rank, buffer, index, size, op) + return op + + def add_wait(self, rank, send_ref, src, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.wait, rank, send_ref, send_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, src_ranks=[src], step=tb_step) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._write(rank, buffer, index, size, op) + return op + # InstructionDAG - adds a recv node def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op): op = Op(Instruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) @@ -164,6 +185,7 @@ def add_recv_reduce_copy(self, rank, send_ref, recv_ref, tb, ch, send_op): def convert_set_list(self): ops = [] + visited = set() for slot, op in self.operations.items(): if op.inst == Instruction.start: op.next = list(op.next) @@ -172,7 +194,6 @@ def convert_set_list(self): elif op.inst != Instruction.copy: ops.append(op) - visited = set() while len(ops) > 0: op = ops[0] if op not in visited: @@ -181,11 +202,15 @@ def convert_set_list(self): ops = ops[1:] + op.next else: ops = ops[1:] + return visited def optimize(self): self._optimize_rrcs_rrs() self._optimize_rcs() + def optimize_mscclpp(self): + pass + # Completes metadata for chunk_steps (number of steps from a start op) and priority (number of steps to the last op) def _complete_metadata(self): def dfs(op, cs): @@ -259,6 +284,14 @@ def _optimize_rrcs_rrs(self): remove_op(next_op) frontier = frontier[1:] + op.next + def _get_tb_step(self, rank, tb): + if tb in self.tb_steps[rank]: + self.tb_steps[rank][tb] += 1 + return self.tb_steps[rank][tb] + else: + self.tb_steps[rank][tb] = 0 + return 0 + def lower_pt1(self, instances): self.infer_dependencies() self.lower_buffers(instances) @@ -287,7 +320,7 @@ def infer_dependencies(self): # Convert local scratch buffers to index into one global scratch buffer def lower_chunk(self, chunk): - if chunk.buffer is not Buffer.input and chunk.buffer is not Buffer.output: + if chunk is not None and chunk.buffer is not Buffer.input and chunk.buffer is not Buffer.output: buffer = self.buffers[chunk.rank][chunk.buffer].get_buffer() index = self.buffers[chunk.rank][chunk.buffer].get_global_index(chunk.index) return ChunkRef(chunk.rank, buffer, index, chunk.size) diff --git a/msccl/language/tb_assignment.py b/msccl/language/tb_assignment.py index 9760be2..b5fe501 100755 --- a/msccl/language/tb_assignment.py +++ b/msccl/language/tb_assignment.py @@ -12,7 +12,7 @@ def _verify_tb_op_compatible(tb, op): s = op.dst.rank if op.is_send() else -1 r = op.src.rank if op.is_recv() else -1 - + sends_ok = tb.send == s or s == -1 or tb.send == -1 recvs_ok = tb.recv == r or r == -1 or tb.recv == -1 channel_ok = tb.channel == op.channel or tb.channel == -1 or op.channel == -1 @@ -22,7 +22,7 @@ def _verify_tb_op_compatible(tb, op): def manual_assign_tbs(rank_dag): instrs = topo_sort_instrs(rank_dag) for op in instrs: - + rank = op.rank tbid = op.tb if tbid not in rank_dag.tbs[rank]: @@ -40,6 +40,17 @@ def manual_assign_tbs(rank_dag): f"Threadblock {tbid} send:{tb.send} recv:{tb.recv} channel:{tb.channel}\n" \ f"Operation send:{op.dst.rank if op.is_send() else -1} recv:{op.dst.rank if op.is_recv() else -1} channel:{op.channel}") +def convert_to_exectuion_plan(instr_dag): + ops = instr_dag.convert_set_list() + ops = sorted(ops, key=lambda x: x.step) + for op in ops: + rank = op.rank + tbid = op.tb + if tbid not in instr_dag.tbs[rank]: + instr_dag.tbs[rank][tbid] = Threadblock() + tb = instr_dag.tbs[rank][tbid] + tb.ops.append(op) + def _get_tb_options(mapping, send, recv, channel, num_tbs): options = [] for tbid, tb in mapping.items(): @@ -75,12 +86,12 @@ def auto_assign_tbs(rank_dag): tbid = rank_tbids[rank] rank_dag.tbs[rank][tbid] = Threadblock(send=s, recv=r, channel=channel) rank_tbids[rank] += 1 - else: + else: tbid = tb_options[0] for tbid_opt in tb_options: if current_tb_step[rank][tbid_opt] < current_tb_step[rank][tbid] and _verify_tb_op_compatible(rank_dag.tbs[rank][tbid], op): tbid = tbid_opt - + tb = rank_dag.tbs[rank][tbid] assert _verify_tb_op_compatible(tb, op), f"Failing: Operations uses channel {op.channel}, send:{s} recv:{r} {op}\n" \ f"Threadblock uses send:{tb.send} recv:{tb.recv} channel:{tb.channel}" @@ -90,13 +101,13 @@ def auto_assign_tbs(rank_dag): tb.ops.append(op) tb.send = op.dst.rank if op.is_send() else tb.send tb.recv = op.src.rank if op.is_recv() else tb.recv - + op.step = len(tb.ops)-1 op.tb = tbid current_tb_step[rank][tbid] = op.chunk_step # Topologically orders instructions so that (1): Sends occur before their receives -# (2): Dependent instructions occur before +# (2): Dependent instructions occur before def topo_sort_instrs(rank_dag): def priority(op): return ((op.chunk_step, -op.priority, op.dst.index)) @@ -117,9 +128,9 @@ def priority(op): rmatch = op.recv_match ordered.append(op) visited.add(op) - + # Add a matching receive if one exists and its dependencies are satisfied - if rmatch is not None and all([x in visited for x in rmatch.prev]): + if rmatch is not None and all([x in visited for x in rmatch.prev]): heapq.heappush(ops, (priority(rmatch), rmatch)) # Add other operation that have dependencies satisfied for o in op.next: @@ -140,7 +151,7 @@ def valid_send_ch(sender, receiver, ch): def valid_recv_ch(sender, receiver, ch): return ch in rank2recvch[receiver][sender] - # Returns a channel this flow can be scheduled on, else -1 + # Returns a channel this flow can be scheduled on, else -1 def is_matching_flow(flow): if flow in flows: return flow_channels[flows.index(flow)] @@ -159,7 +170,7 @@ def create_flow(f): for i in range(1, len(f)): flow.add((f[i-1], f[i])) return flow - + def dfs(op, channels, f): if op.is_local(): op.channel = 0 @@ -210,7 +221,7 @@ def dfs(op, channels, f): if op.is_send(): dst = op.dst.rank pending_recv[(rank, dst, channel)].append(op.recv_match) - + if op.is_recv(): src = op.src.rank pr = pending_recv[(src, rank, channel)] From 24c086325129f765e18897518d77408bcaccd2d9 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sat, 23 Mar 2024 10:49:35 +0000 Subject: [PATCH 04/76] WIP --- examples/mscclang/put_mscclpp.py | 11 +- msccl/language/__init__.py | 30 ++++- msccl/language/ir.py | 209 ++++++++++++++++++++++++++----- msccl/language/rank_dag.py | 14 +-- 4 files changed, 218 insertions(+), 46 deletions(-) diff --git a/examples/mscclang/put_mscclpp.py b/examples/mscclang/put_mscclpp.py index b4bcc82..4c5ac75 100644 --- a/examples/mscclang/put_mscclpp.py +++ b/examples/mscclang/put_mscclpp.py @@ -15,14 +15,15 @@ def allreduce_allpairs(gpus, instances, protocol): interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): c = chunk(0, Buffer.input, 0, size=1) - c.put(1, Buffer.input, index = 0, sendtb=0) - c.put(1, Buffer.input, index = 1, sendtb=0) - c.signal(1, sendtb=0) + c.put(1, Buffer.input, index=0, sendtb=0) + c.put(1, Buffer.input, index=1, sendtb=0) + c.signal(1, Buffer.input, index=0, sendtb=0) + c.signal(1, Buffer.input, index=1, sendtb=0) dc0 = chunk(1, Buffer.input, 1, size=1) dc1 = chunk(1, Buffer.input, 0, size=1) - dc0.wait(0, recvtb=1) - dc1.wait(0, recvtb=1) + dc0.wait(0, Buffer.input, index=0, recvtb=1) + dc1.wait(0, Buffer.input, index=1, recvtb=1) Json() #Check() diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 8a543e7..7173494 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -138,7 +138,7 @@ def generate_xml(self): return ir_to_xml(self.lower(), dependence_nop=self.dependence_nop) def generate_json(self): - return ir_to_xml(self.lower_mscclpp(), dependence_nop=self.dependence_nop) + return ir_to_json(self.lower_mscclpp(), dependence_nop=self.dependence_nop) def print_chunk_dag(self): visualize_chunk_dag(self.chunk_dag.chunk_paths) @@ -237,19 +237,39 @@ def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): def get(self, src, buffer=None, index=-1, recvtb=-1): self.prog.check_buffer_exists(src, buffer) - def signal(self, dst, sendtb=-1, chan_type=ChannelType.sm): + def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): sender = self.rank receiver = dst assert sender != receiver, 'Cannot signal to the same rank' - self.prog.instr_dag.add_signal(sender, self, dst, sendtb, chan_type) + if index == -1 and buffer == None: + index = self.index + buffer = self.buffer + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + index = self.prog.buffers[dst][buffer].instance_size() + + # Direct signal + assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}' + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - def wait(self, src, recvtb=-1, chan_type=ChannelType.sm): + self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type) + + def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): sender = src receiver = self.rank assert sender != receiver, 'Cannot wait on the same rank' - self.prog.instr_dag.add_wait(receiver, self, src, recvtb, chan_type) + if index == -1 and buffer == None: + index = self.index + buffer = self.buffer + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + index = self.prog.buffers[src][buffer].instance_size() + + # Direct signal + assert (self.prog.topo.link(self.rank, src) or src == self.rank), f'No link from {self.rank} to {src}' + src_chunkref = self.prog.get_ref(src, buffer, index, self.size) + + self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type) # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 2c905fc..50bc86a 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import json from lxml import etree as ET from dataclasses import dataclass, field from enum import Enum @@ -28,11 +29,43 @@ class Gpu: outputs: dict = field(default_factory=dict) input_chunks: int = 0 output_chunks: int = 0 + scratch_chunks: int = 0 scratch: dict = field(default_factory=dict) + channels: dict = field(default_factory=dict) def scratch_size(self): return max((idx for addr, idx in self.scratch.items()), default=-1) + 1 +class ChannelType(Enum): + proxy = 'proxy' + sm = 'sm' + none = 'none' + + def __str__(self): + return self.value + +class Buffer(Enum): + input = 'i' + output = 'o' + scratch = 's' + + def __str__(self): + return self.value + + def __lt__(self, other): + return self.value < other.value + + def __gt__(self, other): + return self.value < other.value + +@dataclass +class Channel: + name: str + srcBuffer: Buffer + dstBuffer: Buffer + type: ChannelType + connected_to: int + @dataclass class Threadblock: @@ -41,6 +74,7 @@ class Threadblock: recv: int = -1 ops: list = field(default_factory=list) rbid: int = -1 # threadblock id of the receiver + chan: dict = field(default_factory=dict) def __eq__(self, other): return self is other @@ -88,28 +122,6 @@ def __str__(self): return self.value -class Buffer(Enum): - input = 'i' - output = 'o' - scratch = 's' - - def __str__(self): - return self.value - - def __lt__(self, other): - return self.value < other.value - - def __gt__(self, other): - return self.value < other.value - - -class ChannelType(Enum): - proxy = 'proxy' - sm = 'sm' - - def __str__(self): - return self.value - @dataclass class ChunkRef: @@ -139,9 +151,7 @@ class Op: recv_match = None send_match = None channel: int = -1 - channel_type: ChannelType = ChannelType.sm - dst_ranks: list = field(default_factory=list) - src_ranks: list = field(default_factory=list) + channel_type: ChannelType = ChannelType.none def cnt(self): if self.src: @@ -216,11 +226,11 @@ def __repr__(self): # Instructions where src is on local GPU -_local_src_insts = {Instruction.send, Instruction.copy, Instruction.reduce} +_local_src_insts = {Instruction.send, Instruction.copy, Instruction.reduce, Instruction.put, Instruction.signal} # Instructions where dst is on local GPU _local_dst_insts = {Instruction.recv, Instruction.recv_copy_send, Instruction.recv_reduce_send, Instruction.recv_reduce_copy, Instruction.copy, Instruction.reduce, - Instruction.recv_reduce_copy_send} + Instruction.recv_reduce_copy_send, Instruction.wait} def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): @@ -409,5 +419,146 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= ET.indent(algo_elem, space=' ') return ET.tostring(algo_elem, encoding='unicode') -def ir_to_json(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): - pass +def ir_to_json(program: Program, dependence_nop=False): + # Figure out sizes of buffers based on usage + buffer_sizes = defaultdict(lambda: 0) + for gpu in program.gpus: + for tb in gpu.threadblocks: + for op in tb.ops: + if op.inst in _local_src_insts: + key = (gpu.rank, op.src.buffer) + buffer_sizes[key] = max( + buffer_sizes[key], op.src.index + op.src.size) + if op.inst in _local_dst_insts: + key = (gpu.rank, op.dst.buffer) + buffer_sizes[key] = max( + buffer_sizes[key], op.dst.index + op.dst.size) + gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks) + gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) + gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) + + # get channel info for each GPU and threadblock + + # Filter out dependencies within the same threadblock + op_tb_id = {} + for gpu in program.gpus: + for tb in gpu.threadblocks: + for op in tb.ops: + op_tb_id[op] = op.tb + for gpu in program.gpus: + for tb in gpu.threadblocks: + for op in tb.ops: + op.depends = list( + filter(lambda dep: op_tb_id[dep] != op.tb, op.depends)) + # Filter out redundant dependencies + # e.g. if op1 and op2 depend on op, and op1 happends before op2 + # then op2 does not need to explicitly depend on op + for gpu in program.gpus: + for tb in gpu.threadblocks: + running_depends = [] + for op in tb.ops: + op.depends = list( + filter(lambda dep: dep not in running_depends, op.depends)) + running_depends = running_depends + op.depends + + # Mark all ops that have a dependence on them + has_dependence = set() + for gpu in program.gpus: + for tb in gpu.threadblocks: + for op in tb.ops: + has_dependence.update(op.depends) + + if dependence_nop: + for gpu in program.gpus: + for tb in gpu.threadblocks: + pre_ops = [] + after_ops = [] + first_re = None + first_dep = None + for i, op in enumerate(tb.ops): + # Expand extra dependencies into nop operations + num_depends = len(op.depends) + if op.inst is Instruction.reduce: + if num_depends > 0: + for dep in op.depends: + if first_dep is None: + first_dep = dep + else: + pre_ops.append(Op(Instruction.nop, -1, None, None, [dep])) + op.depends = [] + if first_re is None: + first_re = op + + if first_re is not None: + after_ops.append(op) + else: + pre_ops.append(op) + if first_dep is not None: + first_re.depends = [first_dep] + tb.ops = pre_ops + after_ops + + # Do some additional postprocessing of operations: + # - Expand operations with extra dependencies with no-ops + # - Mark the index of each operation taking any extra no-ops into account + op_idx = {} + for gpu in program.gpus: + for tb in gpu.threadblocks: + new_ops = [] + for op in tb.ops: + # Expand extra dependencies into nop operations + if len(op.depends) > 1: + extra_deps = op.depends[1:] + op.depends = op.depends[:1] + for i, dep in enumerate(extra_deps): + new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) + op_idx[new_ops[-1]] = len(new_ops) - 1 + #op_tb_id[new_ops[-1]] = op_tb_id[op] + new_ops.append(op) + op_idx[new_ops[-1]] = len(new_ops) - 1 + tb.ops = new_ops + + # Need to calculate channel info for each GPU + nchannels = 0 + for gpu in program.gpus: + max_tb_channels = 0 + if len(gpu.threadblocks) > 0: + max_tb_channels = max(tb.channel+1 for tb in gpu.threadblocks) + nchannels = max(nchannels, max_tb_channels) + return dump_to_json(program) + +def dump_to_json(program: Program): + gpus = [] + for id, gpu in enumerate(program.gpus): + gpu_instance = { + 'id': id, + 'input_chunks': gpu.input_chunks, + 'output_chunks': gpu.output_chunks, + 'scratch_size': gpu.scratch_chunks, + 'threadblocks': [] + } + for id, tb in enumerate(gpu.threadblocks): + ops = [] + for op in tb.ops: + instr = { + "name": op.inst.name, + "srcbuff": op.src.buffer.name if op.src else None, + "srcoff": op.src.index if op.src else None, + "dstbuff": op.dst.buffer.name if op.dst else None, + "dstoff": op.dst.index if op.dst else None, + "channel_type": op.channel_type.name, + } + ops.append(instr) + threadblock = { + 'id': id, + 'ops': ops + } + gpu_instance['threadblocks'].append(threadblock) + gpus.append(gpu_instance) + obj = { + 'name': program.name, + 'colletive': program.collective, + 'protocol': program.protocol, + 'inplace': program.inplace, + 'gpus': gpus + } + return json.dumps(obj, indent=2) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 75913e5..734bc75 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -144,9 +144,9 @@ def add_put(self, rank, send_ref, recv_ref, tb, ch_type): return op # InstructionDAG - adds a signal node. - def add_signal(self, rank, send_ref, dst, tb, ch_type): + def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.signal, rank, send_ref, None, next=set(), prev=set(), tb=tb, channel_type=ch_type, dst_ranks=[dst], step=tb_step) + op = Op(Instruction.signal, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) buffer = send_ref.buffer index = send_ref.index size = send_ref.size @@ -154,12 +154,12 @@ def add_signal(self, rank, send_ref, dst, tb, ch_type): self._write(rank, buffer, index, size, op) return op - def add_wait(self, rank, send_ref, src, tb, ch_type): + def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.wait, rank, send_ref, send_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, src_ranks=[src], step=tb_step) - buffer = send_ref.buffer - index = send_ref.index - size = send_ref.size + op = Op(Instruction.wait, rank, src_ref, dst_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) + buffer = dst_ref.buffer + index = dst_ref.index + size = dst_ref.size self._write(rank, buffer, index, size, op) return op From c32229d82a3fb1bd3a92db3f0b48ae66bf78b3ac Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sat, 23 Mar 2024 14:31:59 +0000 Subject: [PATCH 05/76] WIP --- msccl/language/__init__.py | 5 ++++- msccl/language/ir.py | 36 ++++++++++++++++++++++++++++++------ msccl/language/rank_dag.py | 16 ++++++++++++++++ 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 7173494..bb2dc73 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -128,6 +128,7 @@ def lower(self): # Lower program to MSCCLPP def lower_mscclpp(self): convert_to_exectuion_plan(self.instr_dag) + self.instr_dag.complete_channels() if self.instr_fusion: self.instr_dag.optimize_mscclpp() self.instr_dag.lower_pt1(self.instances) @@ -233,10 +234,12 @@ def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): return dst_chunkref - def get(self, src, buffer=None, index=-1, recvtb=-1): self.prog.check_buffer_exists(src, buffer) + # for signal and wait, currently we assuem the pair will use the same tb index. In future we need + # to infer the tb index from the instruction DAG Add a channel is define as (send_tb, src_buffer, recv_tb, dst_buffer, type). + # Then we can use DAG info to reduce the number of channels. def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): sender = self.rank receiver = dst diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 50bc86a..0fcc2a2 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -58,9 +58,8 @@ def __lt__(self, other): def __gt__(self, other): return self.value < other.value -@dataclass +@dataclass(frozen=True) class Channel: - name: str srcBuffer: Buffer dstBuffer: Buffer type: ChannelType @@ -74,7 +73,7 @@ class Threadblock: recv: int = -1 ops: list = field(default_factory=list) rbid: int = -1 # threadblock id of the receiver - chan: dict = field(default_factory=dict) + channels: list = field(default_factory=list) def __eq__(self, other): return self is other @@ -438,6 +437,19 @@ def ir_to_json(program: Program, dependence_nop=False): gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) # get channel info for each GPU and threadblock + for gpu in program.gpus: + chan_dict = {} + # the channel key is the tuple (srcBuffer, dstBuffer, type) + for tid, tb in enumerate(gpu.threadblocks): + for ch in tb.channels: + key = (ch.srcBuffer, ch.dstBuffer, ch.type) + if key not in chan_dict: + chan_dict[key] = [(tid, ch.connected_to)] + else: + chan_dict[key].append((tid, ch.connected_to)) + for key, value in chan_dict.items(): + chan_dict[key] = sorted(value) + gpu.channels = chan_dict # Filter out dependencies within the same threadblock op_tb_id = {} @@ -534,16 +546,28 @@ def dump_to_json(program: Program): 'input_chunks': gpu.input_chunks, 'output_chunks': gpu.output_chunks, 'scratch_size': gpu.scratch_chunks, - 'threadblocks': [] + 'threadblocks': [], + "channels": [] } + for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): + obj = { + "srcBuffer": srcBuffer.name, + "dstBuffer": dstBuffer.name, + "type": type.name, + "connectedTo": [eles[1] for eles in channels], + "threadblockMap": [eles[0] for eles in channels] + } + gpu_instance["channels"].append(obj) for id, tb in enumerate(gpu.threadblocks): ops = [] for op in tb.ops: instr = { "name": op.inst.name, - "srcbuff": op.src.buffer.name if op.src else None, + "src": op.src.rank if op.src else None, + "srcbuff": op.src.buffer.name if op.src.buffer else None, "srcoff": op.src.index if op.src else None, - "dstbuff": op.dst.buffer.name if op.dst else None, + "dst": op.dst.rank if op.dst else None, + "dstbuff": op.dst.buffer.name if op.dst.buffer else None, "dstoff": op.dst.index if op.dst else None, "channel_type": op.channel_type.name, } diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 734bc75..5fa50c9 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -208,6 +208,22 @@ def optimize(self): self._optimize_rrcs_rrs() self._optimize_rcs() + def complete_channels(self): + send_op = [Instruction.put, Instruction.signal] + recv_op = [Instruction.wait] + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + chans = set() + for op in tb.ops: + if op.inst in send_op: + chan = Channel(op.src.buffer, op.dst.buffer, op.channel_type, op.dst.rank) + chans.add(chan) + elif op.inst in recv_op: + chan = Channel(op.dst.buffer, op.src.buffer, op.channel_type, op.src.rank) + chans.add(chan) + tb.channels = list(chans) + pass + def optimize_mscclpp(self): pass From 4929ef645ff70d63b76d5e53be6dc68f1a8d2bb7 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sat, 23 Mar 2024 14:58:04 +0000 Subject: [PATCH 06/76] WIP need algo --- examples/mscclang/put_mscclpp.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/examples/mscclang/put_mscclpp.py b/examples/mscclang/put_mscclpp.py index 4c5ac75..cdf2691 100644 --- a/examples/mscclang/put_mscclpp.py +++ b/examples/mscclang/put_mscclpp.py @@ -14,16 +14,28 @@ def allreduce_allpairs(gpus, instances, protocol): with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): - c = chunk(0, Buffer.input, 0, size=1) - c.put(1, Buffer.input, index=0, sendtb=0) - c.put(1, Buffer.input, index=1, sendtb=0) - c.signal(1, Buffer.input, index=0, sendtb=0) - c.signal(1, Buffer.input, index=1, sendtb=0) - - dc0 = chunk(1, Buffer.input, 1, size=1) - dc1 = chunk(1, Buffer.input, 0, size=1) - dc0.wait(0, Buffer.input, index=0, recvtb=1) - dc1.wait(0, Buffer.input, index=1, recvtb=1) + for rank in range(gpus): + c = chunk(rank, Buffer.input, rank, size=1) + for i in range(gpus - 1): + peer = (rank + i + 1) % gpus + c.put(peer, Buffer.input, rank, sendtb=0) + for i in range(gpus - 1): + peer = (rank + i + 1) % gpus + c.signal(peer, Buffer.input, rank, sendtb=0) + for i in range(gpus - 1): + peer = (rank + i + 1) % gpus + c.wait(peer, Buffer.input, peer, recvtb=0) + + # c = chunk(0, Buffer.input, 0, size=1) + # c.put(1, Buffer.input, index=0, sendtb=0) + # c.put(1, Buffer.input, index=1, sendtb=0) + # c.signal(1, Buffer.input, index=0, sendtb=0) + # c.signal(1, Buffer.input, index=1, sendtb=0) + + # dc0 = chunk(1, Buffer.input, 1, size=1) + # dc1 = chunk(1, Buffer.input, 0, size=1) + # dc0.wait(0, Buffer.input, index=0, recvtb=1) + # dc1.wait(0, Buffer.input, index=1, recvtb=1) Json() #Check() From 787645b7b2c28c6f8b3ec6f5eca34ca5ab2507c6 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sun, 24 Mar 2024 04:08:30 +0000 Subject: [PATCH 07/76] WIP --- .../allreduce_a100_allpairs_packet_mscclpp.py | 50 +++++++++++++++++++ msccl/language/ir.py | 25 +++++++--- msccl/language/tb_assignment.py | 2 +- 3 files changed, 68 insertions(+), 9 deletions(-) create mode 100644 examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py new file mode 100644 index 0000000..c0abd31 --- /dev/null +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + +def allreduce_allpairs(gpus, instances): + size = gpus + chunksperloop = gpus * gpus + topology = fully_connected(size) + collective = AllReduce(size, chunksperloop, True) + with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol="LL", + interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): + + # Each rank sends the nth chunk to the nth rank into scratch space + for r1 in range(size): + for r2 in range(size): + if r1 != r2: + index = r2 * size + c = chunk(r1, Buffer.input, index, size=size) + c.put(r2, 'scratch', index=r1*size, sendtb=r2) + + # # Each rank performs a local reduction on the nth chunk + # # Utilize 8 threadblocks for this reduction for better parallelism + # for r in range(size): + # for index in range(0, size * (size-1)): + # c = chunk(r, Buffer.input, r*size + (index % size)) + # c.reduce(chunk(r, 'scratch', index), sendtb=(index % size)) + + # # Each rank sends the fully reduced nth chunk to all other gpus + # for r1 in range(size): + # for r2 in range(size): + # index = r1 * size + # c = chunk(r1, Buffer.input, index + r2) + # for r3 in range(size): + # if r3 != r1: + # c.put(r3, Buffer.input, index, sendtb=r2) + + Json() + # Check() + +parser = argparse.ArgumentParser() +parser.add_argument('num_gpus', type=int, help ='number of gpus') +parser.add_argument('instances', type=int, help='number of instances') + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 0fcc2a2..f02172f 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -68,6 +68,7 @@ class Channel: @dataclass class Threadblock: + id: int = -1 channel: int = -1 send: int = -1 recv: int = -1 @@ -230,6 +231,7 @@ def __repr__(self): _local_dst_insts = {Instruction.recv, Instruction.recv_copy_send, Instruction.recv_reduce_send, Instruction.recv_reduce_copy, Instruction.copy, Instruction.reduce, Instruction.recv_reduce_copy_send, Instruction.wait} +_send_insts = {Instruction.put} def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): @@ -428,25 +430,31 @@ def ir_to_json(program: Program, dependence_nop=False): key = (gpu.rank, op.src.buffer) buffer_sizes[key] = max( buffer_sizes[key], op.src.index + op.src.size) + if op.inst in _send_insts: + key = (op.dst.rank, op.dst.buffer) + buffer_sizes[key] = max( + buffer_sizes[key], op.dst.index + op.dst.size) if op.inst in _local_dst_insts: key = (gpu.rank, op.dst.buffer) buffer_sizes[key] = max( buffer_sizes[key], op.dst.index + op.dst.size) + for gpu in program.gpus: gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks) gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) # get channel info for each GPU and threadblock for gpu in program.gpus: + gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id) chan_dict = {} # the channel key is the tuple (srcBuffer, dstBuffer, type) - for tid, tb in enumerate(gpu.threadblocks): + for tb in gpu.threadblocks: for ch in tb.channels: key = (ch.srcBuffer, ch.dstBuffer, ch.type) if key not in chan_dict: - chan_dict[key] = [(tid, ch.connected_to)] + chan_dict[key] = [(tb.id, ch.connected_to)] else: - chan_dict[key].append((tid, ch.connected_to)) + chan_dict[key].append((tb.id, ch.connected_to)) for key, value in chan_dict.items(): chan_dict[key] = sorted(value) gpu.channels = chan_dict @@ -545,20 +553,20 @@ def dump_to_json(program: Program): 'id': id, 'input_chunks': gpu.input_chunks, 'output_chunks': gpu.output_chunks, - 'scratch_size': gpu.scratch_chunks, + 'scratch_chunks': gpu.scratch_chunks, 'threadblocks': [], "channels": [] } for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): obj = { - "srcBuffer": srcBuffer.name, - "dstBuffer": dstBuffer.name, + "srcBuffer": srcBuffer.name if hasattr(srcBuffer, 'name') else srcBuffer, + "dstBuffer": dstBuffer.name if hasattr(dstBuffer, 'name') else dstBuffer, "type": type.name, "connectedTo": [eles[1] for eles in channels], "threadblockMap": [eles[0] for eles in channels] } gpu_instance["channels"].append(obj) - for id, tb in enumerate(gpu.threadblocks): + for tb in gpu.threadblocks: ops = [] for op in tb.ops: instr = { @@ -570,10 +578,11 @@ def dump_to_json(program: Program): "dstbuff": op.dst.buffer.name if op.dst.buffer else None, "dstoff": op.dst.index if op.dst else None, "channel_type": op.channel_type.name, + "cnt": op.cnt(), } ops.append(instr) threadblock = { - 'id': id, + 'id': tb.id, 'ops': ops } gpu_instance['threadblocks'].append(threadblock) diff --git a/msccl/language/tb_assignment.py b/msccl/language/tb_assignment.py index b5fe501..83d0171 100755 --- a/msccl/language/tb_assignment.py +++ b/msccl/language/tb_assignment.py @@ -47,7 +47,7 @@ def convert_to_exectuion_plan(instr_dag): rank = op.rank tbid = op.tb if tbid not in instr_dag.tbs[rank]: - instr_dag.tbs[rank][tbid] = Threadblock() + instr_dag.tbs[rank][tbid] = Threadblock(id=tbid) tb = instr_dag.tbs[rank][tbid] tb.ops.append(op) From 0448b1d8897c8e3cd07271f55058596658ecaa3e Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sun, 24 Mar 2024 06:47:43 +0000 Subject: [PATCH 08/76] WIP --- .../allreduce_a100_allpairs_mscclpp.py | 51 ------------------- .../allreduce_a100_allpairs_packet_mscclpp.py | 14 ++--- msccl/language/__init__.py | 4 +- msccl/language/ir.py | 5 ++ msccl/language/rank_dag.py | 13 ++++- 5 files changed, 28 insertions(+), 59 deletions(-) delete mode 100644 examples/mscclang/allreduce_a100_allpairs_mscclpp.py diff --git a/examples/mscclang/allreduce_a100_allpairs_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_mscclpp.py deleted file mode 100644 index 53a81a0..0000000 --- a/examples/mscclang/allreduce_a100_allpairs_mscclpp.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -from msccl.language import * -from msccl.topologies import * -from msccl.language.collectives import AllReduce - -def allreduce_allpairs(gpus, instances, protocol): - size = gpus - chunksperloop = gpus * gpus - topology = fully_connected(size) - collective = AllReduce(size, chunksperloop, True) - with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, - interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): - - # Each rank sends the nth chunk to the nth rank into scratch space - for r1 in range(size): - for r2 in range(size): - if r1 != r2: - index = r2 * size - c = chunk(r1, Buffer.input, index, size=size) - c.put(r2, 'scratch', index=r1, sendtb=r2) - - # Each rank performs a local reduction on the nth chunk - # Utilize 8 threadblocks for this reduction for better parallelism - for r in range(size): - for index in range(0, size * (size-1)): - c = chunk(r, Buffer.input, r*size + (index % size)) - c.reduce(chunk(r, 'scratch', index), sendtb=(index % size)) - - # Each rank sends the fully reduced nth chunk to all other gpus - for r1 in range(size): - for r2 in range(size): - index = r1 * size - c = chunk(r1, Buffer.input, index + r2) - for r3 in range(size): - if r3 != r1: - c.put(r3, Buffer.input, index, sendtb=r2) - - XML() - Check() - -parser = argparse.ArgumentParser() -parser.add_argument('num_gpus', type=int, help ='number of gpus') -parser.add_argument('instances', type=int, help='number of instances') -parser.add_argument('--protocol', type=str, default='LL', choices=['Simple', 'LL'], help='Protocol') - -args = parser.parse_args() - -allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index c0abd31..8713925 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -22,12 +22,14 @@ def allreduce_allpairs(gpus, instances): c = chunk(r1, Buffer.input, index, size=size) c.put(r2, 'scratch', index=r1*size, sendtb=r2) - # # Each rank performs a local reduction on the nth chunk - # # Utilize 8 threadblocks for this reduction for better parallelism - # for r in range(size): - # for index in range(0, size * (size-1)): - # c = chunk(r, Buffer.input, r*size + (index % size)) - # c.reduce(chunk(r, 'scratch', index), sendtb=(index % size)) + # Each rank performs a local reduction on the nth chunk + # Utilize 8 threadblocks for this reduction for better parallelism + for r in range(size): + for index in range(size): + c = chunk(r, Buffer.input, r*size + index) + for peer in range(size): + if peer != r: + c.reduce(chunk(r, 'scratch', peer*size+index), sendtb=index) # # Each rank sends the fully reduced nth chunk to all other gpus # for r1 in range(size): diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index bb2dc73..05a941c 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -230,7 +230,9 @@ def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) + sop = self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) + if self.prog.protocol == 'LL': + self.prog.instr_dag.add_packet_recv(receiver, self, dst_chunkref, chan_type, sop) return dst_chunkref diff --git a/msccl/language/ir.py b/msccl/language/ir.py index f02172f..965ddaf 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -113,6 +113,7 @@ class Instruction(Enum): delete = 'd' start = 'st' put = 'put' + packet_recv = 'packet_recv' get = 'get' wait = 'wait' signal = 'signal' @@ -567,8 +568,12 @@ def dump_to_json(program: Program): } gpu_instance["channels"].append(obj) for tb in gpu.threadblocks: + if tb.id == -1: + continue ops = [] for op in tb.ops: + if op.tb == -1: + continue instr = { "name": op.inst.name, "src": op.src.rank if op.src else None, diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 5fa50c9..d55c245 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -111,7 +111,8 @@ def add_copy(self, rank, send_ref, recv_ref, tb, ch): # InstructionDAG - adds a redduce node def add_reduce(self, rank, send_ref, recv_ref, tb, ch): - op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch, step=tb_step) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer @@ -143,6 +144,16 @@ def add_put(self, rank, send_ref, recv_ref, tb, ch_type): self._read(rank, buffer, index, size, op) return op + def add_packet_recv(self, rank, send_ref, recv_ref, ch_type, send_op): + # This is mock instruction for packet recv + op = Op(Instruction.packet_recv, rank, send_ref, recv_ref, next=set(), prev=set(), channel_type=ch_type) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op) + op.send_match = send_op + return op + # InstructionDAG - adds a signal node. def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) From e23dc1832b1970cd74b2e975392e1025024ef6b9 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sun, 24 Mar 2024 07:28:03 +0000 Subject: [PATCH 09/76] WIP need fuse --- .../allreduce_a100_allpairs_packet_mscclpp.py | 18 ++++++++++-------- msccl/language/rank_dag.py | 3 ++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index 8713925..7d7c52a 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -30,15 +30,17 @@ def allreduce_allpairs(gpus, instances): for peer in range(size): if peer != r: c.reduce(chunk(r, 'scratch', peer*size+index), sendtb=index) + for peer in range(size): + if peer != r: + c.put(peer, 'scratch', (size*size)+r*size+index, sendtb=index) - # # Each rank sends the fully reduced nth chunk to all other gpus - # for r1 in range(size): - # for r2 in range(size): - # index = r1 * size - # c = chunk(r1, Buffer.input, index + r2) - # for r3 in range(size): - # if r3 != r1: - # c.put(r3, Buffer.input, index, sendtb=r2) + # Each rank get final result from scratch space + for r in range(size): + for index in range(size): + for peer in range(size): + if peer != r: + c = chunk(r, 'scratch', size*size+peer*size+index) + c.copy(r, Buffer.input, peer*size+index, sendtb=index) Json() # Check() diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index d55c245..9c0e0e3 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -97,7 +97,8 @@ def add_start(self, rank, buffer, index, ref): # InstructionDAG - adds a copy node def add_copy(self, rank, send_ref, recv_ref, tb, ch): - op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch, step=tb_step) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer From 2e08484e5c1a0cde036c1d07c5fc6a59c81a95b6 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sun, 24 Mar 2024 09:33:11 +0000 Subject: [PATCH 10/76] WIP --- ...lpp_v2.py => allreduce_a100_allpairs_sm_mscclpp.py} | 6 +++--- msccl/language/__init__.py | 5 +++-- msccl/language/ir.py | 1 - msccl/language/rank_dag.py | 10 ---------- 4 files changed, 6 insertions(+), 16 deletions(-) rename examples/mscclang/{allreduce_a100_allpairs_mscclpp_v2.py => allreduce_a100_allpairs_sm_mscclpp.py} (92%) diff --git a/examples/mscclang/allreduce_a100_allpairs_mscclpp_v2.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py similarity index 92% rename from examples/mscclang/allreduce_a100_allpairs_mscclpp_v2.py rename to examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index 4589e9b..3ef3621 100644 --- a/examples/mscclang/allreduce_a100_allpairs_mscclpp_v2.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -21,7 +21,7 @@ def allreduce_allpairs(gpus, instances, protocol): c = chunk(rank, Buffer.input, index + tb) for nghr in range(size): if rank != nghr: - c.reduce(chunk(nghr, 'input', index + tb), recvtb==tb) + c.reduce(chunk(nghr, 'input', index + tb), recvtb=tb) # Each rank sends the fully reduced nth chunk to all other gpus for rank in range(size): @@ -32,13 +32,13 @@ def allreduce_allpairs(gpus, instances, protocol): if rank != nghr: c.put(nghr, Buffer.input, index, sendtb=tb) - XML() + Json() Check() parser = argparse.ArgumentParser() parser.add_argument('num_gpus', type=int, help ='number of gpus') parser.add_argument('instances', type=int, help='number of instances') -parser.add_argument('--protocol', type=str, default='LL', choices=['Simple', 'LL'], help='Protocol') +parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple'], help='Protocol') args = parser.parse_args() diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 05a941c..a5b2754 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -230,9 +230,10 @@ def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - sop = self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) + self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) if self.prog.protocol == 'LL': - self.prog.instr_dag.add_packet_recv(receiver, self, dst_chunkref, chan_type, sop) + self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) + self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) return dst_chunkref diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 965ddaf..30de374 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -113,7 +113,6 @@ class Instruction(Enum): delete = 'd' start = 'st' put = 'put' - packet_recv = 'packet_recv' get = 'get' wait = 'wait' signal = 'signal' diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 9c0e0e3..7316c94 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -145,16 +145,6 @@ def add_put(self, rank, send_ref, recv_ref, tb, ch_type): self._read(rank, buffer, index, size, op) return op - def add_packet_recv(self, rank, send_ref, recv_ref, ch_type, send_op): - # This is mock instruction for packet recv - op = Op(Instruction.packet_recv, rank, send_ref, recv_ref, next=set(), prev=set(), channel_type=ch_type) - buffer = recv_ref.buffer - index = recv_ref.index - size = recv_ref.size - self._write(rank, buffer, index, size, op) - op.send_match = send_op - return op - # InstructionDAG - adds a signal node. def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) From e634e6fa3fbb7e2d64a5e19ecf4d05f60154e8cd Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 25 Mar 2024 06:49:03 +0000 Subject: [PATCH 11/76] WIP --- .../allreduce_a100_allpairs_packet_mscclpp.py | 2 +- .../allreduce_a100_allpairs_sm_mscclpp.py | 25 ++++++-- .../allreduce_a100_allpairs_sm_mscclpp_get.py | 59 +++++++++++++++++++ msccl/language/__init__.py | 44 +++++++++++++- msccl/language/ir.py | 1 + msccl/language/rank_dag.py | 18 ++++++ 6 files changed, 139 insertions(+), 10 deletions(-) create mode 100644 examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index 7d7c52a..bc12011 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -29,7 +29,7 @@ def allreduce_allpairs(gpus, instances): c = chunk(r, Buffer.input, r*size + index) for peer in range(size): if peer != r: - c.reduce(chunk(r, 'scratch', peer*size+index), sendtb=index) + c.reduce_mscclpp(chunk(r, 'scratch', peer*size+index), sendtb=index) for peer in range(size): if peer != r: c.put(peer, 'scratch', (size*size)+r*size+index, sendtb=index) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index 3ef3621..e1ea01f 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -19,21 +19,34 @@ def allreduce_allpairs(gpus, instances, protocol): for tb in range(size): index = rank * size c = chunk(rank, Buffer.input, index + tb) + # make sure the data is ready for nghr in range(size): if rank != nghr: - c.reduce(chunk(nghr, 'input', index + tb), recvtb=tb) + c.signal(nghr, Buffer.input, index + tb, sendtb=tb) + for nghr in range(size): + index = nghr * size + if rank != nghr: + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + for nghr in range(size): + if rank != nghr: + c.reduce_mscclpp(chunk(nghr, Buffer.input, index + tb), recvtb=tb) + for nghr in range(size): + if rank != nghr: + c.put(nghr, Buffer.input, index, sendtb=tb) + for nghr in range(size): + if rank != nghr: + c.signal(nghr, Buffer.input, index, sendtb=tb) - # Each rank sends the fully reduced nth chunk to all other gpus + # wait for all the chunks to be received for rank in range(size): for tb in range(size): - index = rank * size - c = chunk(rank, Buffer.input, index + tb) for nghr in range(size): if rank != nghr: - c.put(nghr, Buffer.input, index, sendtb=tb) + index = nghr * size + c = chunk(rank, Buffer.input, index + tb) + c.wait(nghr, Buffer.input, index, recvtb=tb) Json() - Check() parser = argparse.ArgumentParser() parser.add_argument('num_gpus', type=int, help ='number of gpus') diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py new file mode 100644 index 0000000..138c39e --- /dev/null +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + +def allreduce_allpairs(gpus, instances, protocol): + size = gpus + chunksperloop = gpus * gpus + topology = fully_connected(size) + collective = AllReduce(size, chunksperloop, True) + with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, + interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): + + # Each rank sends the nth chunk to the nth rank into scratch space + for rank in range(size): + for tb in range(size): + index = rank * size + c = chunk(rank, Buffer.input, index + tb) + # make sure the data is ready + for nghr in range(size): + if rank != nghr: + c.signal(nghr, Buffer.input, index + tb, sendtb=tb) + for nghr in range(size): + index = nghr * size + if rank != nghr: + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + # reduce the chunks + for nghr in range(size): + if rank != nghr: + c.reduce_mscclpp(chunk(nghr, Buffer.input, index + tb), recvtb=tb) + for nghr in range(size): + if rank != nghr: + c.signal(nghr, Buffer.input, index, sendtb=tb) + + # wait for all the chunks is ready, then get the chunks + for rank in range(size): + for tb in range(size): + for nghr in range(size): + if rank != nghr: + index = nghr * size + c = chunk(rank, Buffer.input, index + tb) + c.wait(nghr, Buffer.input, index, recvtb=tb) + for nghr in range(size): + if rank != nghr: + c.get(nghr, Buffer.input, index, recvtb=tb) + + Json() + +parser = argparse.ArgumentParser() +parser.add_argument('num_gpus', type=int, help ='number of gpus') +parser.add_argument('instances', type=int, help='number of instances') +parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple'], help='Protocol') + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index a5b2754..f45249b 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -20,6 +20,11 @@ def _curr(): raise RuntimeError("No Program in context") return _current_program +# For msccl++ program, we have one assumption that for channel can be identified by (send_buffer, recv_buffer, type, send_tb/recv_tb) +# which means the send_tb and recv_tb should be the same for a pair of signal and wait, also same for put/get operation. +# If one sender what to send data to peer want to use different tb in receiver side. We need to send to same tb in receiver side first, +# then performance a across tb sync. This is a limitation of current implementation. + class MSCCLProgram: def __init__(self, name, topo, collective, instances, protocol='Simple', \ threadblock_policy=ThreadblockPolicy.auto, interleaved_replication=True, @@ -235,10 +240,28 @@ def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) - return dst_chunkref - - def get(self, src, buffer=None, index=-1, recvtb=-1): + def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): self.prog.check_buffer_exists(src, buffer) + sender = src + receiver = self.rank + assert sender != receiver, 'Cannot get from the same rank' + + # If index is not specified assume it is going to the same place in the next gpu + if index == -1 and buffer == None: + index = self.index + buffer = self.buffer + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + index = self.prog.buffers[src][buffer].instance_size() + + # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) + buffer, index = self.prog.collective.get_buffer_index(src, buffer, index) + + # Direct get + assert (self.prog.topo.link(self.rank, src) or src == self.rank), f'No link from {self.rank} to {src}' + src_chunkref = self.prog.get_ref(src, buffer, index, self.size) + + self.prog.apply_send(src, buffer, index, self.rank, self.buffer, self.index, self.size) + self.prog.instr_dag.add_get(receiver, src_chunkref, self, recvtb, chan_type) # for signal and wait, currently we assuem the pair will use the same tb index. In future we need # to infer the tb index from the instruction DAG Add a channel is define as (send_tb, src_buffer, recv_tb, dst_buffer, type). @@ -340,6 +363,21 @@ def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): return self + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce_mscclpp(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType.sm): + # Receive reduce copy + dst = self.rank + src = other_chunkref.rank + assert (self.prog.topo.link(src, dst) or src == dst), f'No link from {src} to {dst}' + self.prog.apply_reduce(src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size) + + if src != dst: + self.prog.instr_dag.add_read_reduce_copy(dst, other_chunkref, self, recvtb, channel_type) + else: + self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ChannelType.none) + + return self + def get_origin_index(self, index=0): return self._get_chunk(index + self.index).origin_index diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 30de374..1aa14e3 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -108,6 +108,7 @@ class Instruction(Enum): recv_reduce_send = 'rrs' recv_reduce_copy = 'rrc' recv_reduce_copy_send = 'rrcs' + read_reduce_copy = "rrc" copy = 'cpy' reduce = 're' delete = 'd' diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 7316c94..1c61474 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -145,6 +145,15 @@ def add_put(self, rank, send_ref, recv_ref, tb, ch_type): self._read(rank, buffer, index, size, op) return op + def add_get(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.get, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op) + return op + # InstructionDAG - adds a signal node. def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) @@ -185,6 +194,15 @@ def add_recv_reduce_copy(self, rank, send_ref, recv_ref, tb, ch, send_op): op.send_match = send_op return op + def add_read_reduce_copy(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.read_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op, read=True) + return op + def convert_set_list(self): ops = [] visited = set() From f8fe329f298be1d69f8379ef346f60a079560d2f Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 25 Mar 2024 11:44:18 +0000 Subject: [PATCH 12/76] need more fuse --- .../allreduce_a100_allpairs_sm_mscclpp.py | 10 +- .../allreduce_a100_allpairs_sm_mscclpp_get.py | 1 - msccl/language/__init__.py | 2 +- msccl/language/ir.py | 3 + msccl/language/rank_dag.py | 152 +++++++++++++++++- 5 files changed, 156 insertions(+), 12 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index e1ea01f..5840ae5 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -14,32 +14,28 @@ def allreduce_allpairs(gpus, instances, protocol): with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): - # Each rank sends the nth chunk to the nth rank into scratch space for rank in range(size): for tb in range(size): index = rank * size c = chunk(rank, Buffer.input, index + tb) - # make sure the data is ready + # step1 make sure the data is ready for nghr in range(size): if rank != nghr: c.signal(nghr, Buffer.input, index + tb, sendtb=tb) for nghr in range(size): - index = nghr * size if rank != nghr: c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + # step2 reduce the chunks and send to peers for nghr in range(size): if rank != nghr: c.reduce_mscclpp(chunk(nghr, Buffer.input, index + tb), recvtb=tb) for nghr in range(size): if rank != nghr: c.put(nghr, Buffer.input, index, sendtb=tb) + # step3 signal the peers to receive the chunks for nghr in range(size): if rank != nghr: c.signal(nghr, Buffer.input, index, sendtb=tb) - - # wait for all the chunks to be received - for rank in range(size): - for tb in range(size): for nghr in range(size): if rank != nghr: index = nghr * size diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py index 138c39e..165cc02 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -24,7 +24,6 @@ def allreduce_allpairs(gpus, instances, protocol): if rank != nghr: c.signal(nghr, Buffer.input, index + tb, sendtb=tb) for nghr in range(size): - index = nghr * size if rank != nghr: c.wait(nghr, Buffer.input, index + tb, recvtb=tb) # reduce the chunks diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index f45249b..ffa4524 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -135,7 +135,7 @@ def lower_mscclpp(self): convert_to_exectuion_plan(self.instr_dag) self.instr_dag.complete_channels() if self.instr_fusion: - self.instr_dag.optimize_mscclpp() + self.instr_dag.optimize_mscclpp(self.protocol) self.instr_dag.lower_pt1(self.instances) gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 1aa14e3..4df345e 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -109,6 +109,7 @@ class Instruction(Enum): recv_reduce_copy = 'rrc' recv_reduce_copy_send = 'rrcs' read_reduce_copy = "rrc" + reduce_send = 'rs' copy = 'cpy' reduce = 're' delete = 'd' @@ -153,6 +154,8 @@ class Op: send_match = None channel: int = -1 channel_type: ChannelType = ChannelType.none + srcs: list = field(default_factory=list) + dsts: list = field(default_factory=list) def cnt(self): if self.src: diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 1c61474..fc154eb 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -27,6 +27,15 @@ def same_count(op1, op2): def same_buf_dst(op1, op2): return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index +def buf_dst_src_match(op1, op2): + return op1.dst.buffer == op2.src.buffer and op1.dst.index == op2.src.index + +def same_buf_src(op1, op2): + return op1.src.buffer == op2.src.buffer and op1.src.index == op2.src.index + +def same_chan_type(op1, op2): + return op1.channel_type == op2.channel_type + class InstructionDAG: def __init__(self, num_ranks, buffers): self.num_ranks = num_ranks @@ -120,6 +129,7 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, ch): srcindex = send_ref.index size = recv_ref.size prev_ops = [] + op.srcs.append(ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size)) # Sending part of reduce self._read(rank, srcbuffer, srcindex, size, op) # Reduce part of copy @@ -163,6 +173,7 @@ def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): size = send_ref.size # treat signal as a write since it can not be executed parallelly with read operations self._write(rank, buffer, index, size, op) + op.dsts.append(ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size)) return op def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): @@ -172,6 +183,7 @@ def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): index = dst_ref.index size = dst_ref.size self._write(rank, buffer, index, size, op) + op.srcs.append(ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size)) return op # InstructionDAG - adds a recv node @@ -200,6 +212,7 @@ def add_read_reduce_copy(self, rank, send_ref, recv_ref, tb, ch_type): buffer = recv_ref.buffer index = recv_ref.index size = recv_ref.size + op.srcs.append(ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size)) self._write(rank, buffer, index, size, op, read=True) return op @@ -242,10 +255,138 @@ def complete_channels(self): chan = Channel(op.dst.buffer, op.src.buffer, op.channel_type, op.src.rank) chans.add(chan) tb.channels = list(chans) - pass - def optimize_mscclpp(self): - pass + def _optimize_redandant_signal_wait(self, protocol): + if protocol != 'LL': + return + # For LL algorithm, we can remove signal/wait + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.put: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.signal: + remove_op(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.reduce or op.inst == Instruction.read_reduce_copy or op.inst == Instruction.copy: + fused = False + for prev_op in op.prev: + if prev_op.inst == Instruction.wait: + remove_op(prev_op) + fused = True + break + if fused: + continue + queue = queue[1:] + + # rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di) + # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di]) + # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) + # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di) + def _optimize_rrc_r_signal_wait(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.recv_reduce_copy: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.recv_reduce_copy and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op): + op.srcs.append(ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size)) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.reduce: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.reduce and same_buf_dst(op, next_op) and same_chan_type(op, next_op): + op.srcs.append(ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size)) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.signal: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.signal and same_buf_src(op, next_op) and same_chan_type(op, next_op): + op.dsts.append(ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size)) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.wait: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.wait and same_buf_dst(op, next_op) and same_chan_type(op, next_op): + op.srcs.append(ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size)) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + queue = queue[1:] + + # rrc(_,_,_,dst,dbuf,di) send(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_) + # reduce(_,_,_,dst,dbuf,di) send(dst,dbuf,di,_,_,_) -> rrs(_,_,_,_,_,_) + def _optimize_rrcs_rs(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.recv_reduce_copy or op.inst == Instruction.recv_reduce_copy_send: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op) and same_chan_type(op, next_op): + if op.inst == Instruction.recv_reduce_copy: + op.inst = Instruction.recv_reduce_copy_send + op.dsts.append(ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size)) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + if op.inst == Instruction.reduce or op.inst == Instruction.reduce_send: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op): + if op.inst == Instruction.reduce: + op.inst = Instruction.reduce_send + op.dsts.append(ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size)) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + queue = queue[1:] + + def optimize_mscclpp(self, protocol): + self._optimize_redandant_signal_wait(protocol) + self._optimize_rrc_r_signal_wait() + self._optimize_rrcs_rs() # Completes metadata for chunk_steps (number of steps from a start op) and priority (number of steps to the last op) def _complete_metadata(self): @@ -320,6 +461,11 @@ def _optimize_rrcs_rrs(self): remove_op(next_op) frontier = frontier[1:] + op.next + # For signal/wait ops, if they are independent of other operations and no other operations in between, + # then merge them into a single signal/wait op + def _parallel_signal_wait(self): + pass + def _get_tb_step(self, rank, tb): if tb in self.tb_steps[rank]: self.tb_steps[rank][tb] += 1 From b4c08c91cca79c2742cd63bba0f2061a5da7378c Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 25 Mar 2024 14:55:40 +0000 Subject: [PATCH 13/76] WIP --- msccl/language/rank_dag.py | 63 ++++++++++++++++++++++++++++++++++---- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index fc154eb..170a45e 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -16,7 +16,19 @@ def remove_op(op): for n in op.next: n.prev.remove(op) - n.prev = op.prev.union(n.prev) + n.prev = op.prev.union(n.prev) + +def merge_op(op, other_op): + for p in other_op.prev: + p.next.remove(other_op) + p.next.append(op) + + for n in other_op.next: + n.prev.remove(other_op) + n.prev.add(op) + + op.prev = op.prev.union(other_op.prev) + op.next += other_op.next def same_tb(op1, op2): return op1.tb == op2.tb and op1.channel == op2.channel @@ -27,6 +39,9 @@ def same_count(op1, op2): def same_buf_dst(op1, op2): return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index +def same_src_dst_buffer_type(op1, op2): + return op1.src.buffer == op2.src.buffer and op1.dst.buffer == op2.dst.buffer + def buf_dst_src_match(op1, op2): return op1.dst.buffer == op2.src.buffer and op1.dst.index == op2.src.index @@ -174,6 +189,7 @@ def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): # treat signal as a write since it can not be executed parallelly with read operations self._write(rank, buffer, index, size, op) op.dsts.append(ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size)) + op.srcs.append(ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size)) return op def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): @@ -184,6 +200,7 @@ def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): size = dst_ref.size self._write(rank, buffer, index, size, op) op.srcs.append(ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size)) + op.dsts.append(ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size)) return op # InstructionDAG - adds a recv node @@ -383,11 +400,50 @@ def _optimize_rrcs_rs(self): continue queue = queue[1:] + # For signal/wait ops, if they are independent of other operations and no other operations in between, + # then merge them into a single signal/wait op + # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) + def _parallel_signal_wait(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.signal: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if seq_op.inst == Instruction.signal and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op): + op.dsts.append(ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size)) + op.srcs.append(ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size)) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + elif op.inst == Instruction.wait: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if seq_op.inst == Instruction.wait and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op): + op.dsts.append(ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size)) + op.srcs.append(ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size)) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + queue = queue[1:] + def optimize_mscclpp(self, protocol): self._optimize_redandant_signal_wait(protocol) self._optimize_rrc_r_signal_wait() self._optimize_rrcs_rs() + self._parallel_signal_wait() + # Completes metadata for chunk_steps (number of steps from a start op) and priority (number of steps to the last op) def _complete_metadata(self): def dfs(op, cs): @@ -461,11 +517,6 @@ def _optimize_rrcs_rrs(self): remove_op(next_op) frontier = frontier[1:] + op.next - # For signal/wait ops, if they are independent of other operations and no other operations in between, - # then merge them into a single signal/wait op - def _parallel_signal_wait(self): - pass - def _get_tb_step(self, rank, tb): if tb in self.tb_steps[rank]: self.tb_steps[rank][tb] += 1 From e4066321ff2a28625575812d11e7192dfae5a7f7 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 26 Mar 2024 08:56:30 +0000 Subject: [PATCH 14/76] WIP --- .../allreduce_a100_allpairs_sm_mscclpp_get.py | 8 +- msccl/language/ir.py | 96 ++++++++++++++++--- msccl/language/rank_dag.py | 52 +++++----- 3 files changed, 115 insertions(+), 41 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py index 165cc02..d9442fd 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -32,7 +32,7 @@ def allreduce_allpairs(gpus, instances, protocol): c.reduce_mscclpp(chunk(nghr, Buffer.input, index + tb), recvtb=tb) for nghr in range(size): if rank != nghr: - c.signal(nghr, Buffer.input, index, sendtb=tb) + c.signal(nghr, Buffer.input, index + tb, sendtb=tb) # wait for all the chunks is ready, then get the chunks for rank in range(size): @@ -41,10 +41,12 @@ def allreduce_allpairs(gpus, instances, protocol): if rank != nghr: index = nghr * size c = chunk(rank, Buffer.input, index + tb) - c.wait(nghr, Buffer.input, index, recvtb=tb) + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) for nghr in range(size): + index = nghr * size if rank != nghr: - c.get(nghr, Buffer.input, index, recvtb=tb) + c = chunk(rank, Buffer.input, index + tb) + c.get(nghr, Buffer.input, index + tb, recvtb=tb) Json() diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 4df345e..5d1b11a 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -109,6 +109,7 @@ class Instruction(Enum): recv_reduce_copy = 'rrc' recv_reduce_copy_send = 'rrcs' read_reduce_copy = "rrc" + read_reduce_copy_send = "rrcs" reduce_send = 'rs' copy = 'cpy' reduce = 're' @@ -235,7 +236,7 @@ def __repr__(self): _local_dst_insts = {Instruction.recv, Instruction.recv_copy_send, Instruction.recv_reduce_send, Instruction.recv_reduce_copy, Instruction.copy, Instruction.reduce, Instruction.recv_reduce_copy_send, Instruction.wait} -_send_insts = {Instruction.put} +_send_insts = {Instruction.put} # do we need this? def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): @@ -552,6 +553,14 @@ def ir_to_json(program: Program, dependence_nop=False): def dump_to_json(program: Program): gpus = [] + + def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_type): + channel_ids = [] + for c in chunk_list: + key = (src_buffer, dst_buffer, chan_type) + channel_ids.extend([{"id": id, "off": c.index} for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) if ele == c.rank]) + return channel_ids + for id, gpu in enumerate(program.gpus): gpu_instance = { 'id': id, @@ -566,32 +575,89 @@ def dump_to_json(program: Program): "srcBuffer": srcBuffer.name if hasattr(srcBuffer, 'name') else srcBuffer, "dstBuffer": dstBuffer.name if hasattr(dstBuffer, 'name') else dstBuffer, "type": type.name, - "connectedTo": [eles[1] for eles in channels], - "threadblockMap": [eles[0] for eles in channels] + "connectedTo": [eles[1] for eles in channels] } gpu_instance["channels"].append(obj) for tb in gpu.threadblocks: if tb.id == -1: continue ops = [] + tb_channels = [] + tb_channel_dict = {} + for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): + obj = { + "srcBuffer": srcBuffer.name if hasattr(srcBuffer, 'name') else srcBuffer, + "dstBuffer": dstBuffer.name if hasattr(dstBuffer, 'name') else dstBuffer, + "type": type.name, + "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id], + "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id], + } + tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj + tb_channels.append(obj) for op in tb.ops: if op.tb == -1: continue - instr = { - "name": op.inst.name, - "src": op.src.rank if op.src else None, - "srcbuff": op.src.buffer.name if op.src.buffer else None, - "srcoff": op.src.index if op.src else None, - "dst": op.dst.rank if op.dst else None, - "dstbuff": op.dst.buffer.name if op.dst.buffer else None, - "dstoff": op.dst.index if op.dst else None, - "channel_type": op.channel_type.name, - "cnt": op.cnt(), - } + if op.inst == Instruction.signal: + # get dst channel ids + dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + instr = { + "name": op.inst.name, + "o_cids": dst_channel_ids, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "ctype": op.channel_type.value, + } + elif op.inst == Instruction.wait: + # get src channel ids + src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + instr = { + "name": op.inst.name, + "i_cids": src_channel_ids, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "ctype": op.channel_type.value, + } + elif op.inst == Instruction.read_reduce_copy: + src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + instr = { + "name": op.inst.value, + "i_cids": src_channel_ids, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "dstoff": op.dst.index if op.dst else None, + "ctype": op.channel_type.value, + "cnt": op.cnt(), + } + elif op.inst == Instruction.read_reduce_copy_send: + src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + instr = { + "name": op.inst.value, + "i_cids": src_channel_ids, + "o_cids": dst_channel_ids, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "dstoff": op.dst.index if op.dst else None, + "ctype": op.channel_type.value, + "cnt": op.cnt(), + } + else: + instr = { + "name": op.inst.value, + "src": op.src.rank if op.src else None, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "srcoff": op.src.index if op.src else None, + "dst": op.dst.rank if op.dst else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "dstoff": op.dst.index if op.dst else None, + "channel_type": op.channel_type.value, + "cnt": op.cnt(), + } ops.append(instr) threadblock = { 'id': tb.id, - 'ops': ops + 'ops': ops, + 'channels': list(map(lambda x: {"s": x["srcBuffer"], "d": x["dstBuffer"], "t": x["type"], "cid": x["chanIds"]}, tb_channels)) } gpu_instance['threadblocks'].append(threadblock) gpus.append(gpu_instance) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 170a45e..49199c7 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -144,7 +144,7 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, ch): srcindex = send_ref.index size = recv_ref.size prev_ops = [] - op.srcs.append(ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size)), tb_step) # Sending part of reduce self._read(rank, srcbuffer, srcindex, size, op) # Reduce part of copy @@ -188,8 +188,8 @@ def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): size = send_ref.size # treat signal as a write since it can not be executed parallelly with read operations self._write(rank, buffer, index, size, op) - op.dsts.append(ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size)) - op.srcs.append(ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size)) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) return op def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): @@ -199,8 +199,8 @@ def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): index = dst_ref.index size = dst_ref.size self._write(rank, buffer, index, size, op) - op.srcs.append(ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size)) - op.dsts.append(ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size)) + op.srcs.append((ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size), tb_step)) + op.dsts.append((ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size), tb_step)) return op # InstructionDAG - adds a recv node @@ -229,7 +229,7 @@ def add_read_reduce_copy(self, rank, send_ref, recv_ref, tb, ch_type): buffer = recv_ref.buffer index = recv_ref.index size = recv_ref.size - op.srcs.append(ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) self._write(rank, buffer, index, size, op, read=True) return op @@ -312,11 +312,11 @@ def _optimize_rrc_r_signal_wait(self): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.recv_reduce_copy: + if op.inst == Instruction.read_reduce_copy: fused = False for next_op in op.next: - if next_op.inst == Instruction.recv_reduce_copy and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op): - op.srcs.append(ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size)) + if next_op.inst == Instruction.read_reduce_copy and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op): + op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) queue.remove(next_op) @@ -328,7 +328,7 @@ def _optimize_rrc_r_signal_wait(self): fused = False for next_op in op.next: if next_op.inst == Instruction.reduce and same_buf_dst(op, next_op) and same_chan_type(op, next_op): - op.srcs.append(ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size)) + op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size)), next_op.step) remove_op(next_op) tb.ops.remove(next_op) queue.remove(next_op) @@ -340,7 +340,8 @@ def _optimize_rrc_r_signal_wait(self): fused = False for next_op in op.next: if next_op.inst == Instruction.signal and same_buf_src(op, next_op) and same_chan_type(op, next_op): - op.dsts.append(ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size)) + op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) + op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) queue.remove(next_op) @@ -352,7 +353,8 @@ def _optimize_rrc_r_signal_wait(self): fused = False for next_op in op.next: if next_op.inst == Instruction.wait and same_buf_dst(op, next_op) and same_chan_type(op, next_op): - op.srcs.append(ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size)) + op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) + op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size),next_op.step)) remove_op(next_op) tb.ops.remove(next_op) queue.remove(next_op) @@ -362,21 +364,21 @@ def _optimize_rrc_r_signal_wait(self): continue queue = queue[1:] - # rrc(_,_,_,dst,dbuf,di) send(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_) - # reduce(_,_,_,dst,dbuf,di) send(dst,dbuf,di,_,_,_) -> rrs(_,_,_,_,_,_) + # rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_) + # reduce(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rs(_,_,_,_,_,_) def _optimize_rrcs_rs(self): for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.recv_reduce_copy or op.inst == Instruction.recv_reduce_copy_send: + if op.inst == Instruction.read_reduce_copy or op.inst == Instruction.read_reduce_copy_send: fused = False for next_op in op.next: if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op) and same_chan_type(op, next_op): - if op.inst == Instruction.recv_reduce_copy: - op.inst = Instruction.recv_reduce_copy_send - op.dsts.append(ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size)) + if op.inst == Instruction.read_reduce_copy: + op.inst = Instruction.read_reduce_copy_send + op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) queue.remove(next_op) @@ -390,7 +392,7 @@ def _optimize_rrcs_rs(self): if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op): if op.inst == Instruction.reduce: op.inst = Instruction.reduce_send - op.dsts.append(ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size)) + op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) queue.remove(next_op) @@ -414,8 +416,8 @@ def _parallel_signal_wait(self): if len(queue) > 1: seq_op = queue[1] if seq_op.inst == Instruction.signal and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op): - op.dsts.append(ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size)) - op.srcs.append(ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size)) + op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) + op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) merge_op(op, seq_op) tb.ops.remove(seq_op) queue.remove(seq_op) @@ -427,8 +429,8 @@ def _parallel_signal_wait(self): if len(queue) > 1: seq_op = queue[1] if seq_op.inst == Instruction.wait and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op): - op.dsts.append(ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size)) - op.srcs.append(ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size)) + op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) + op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) merge_op(op, seq_op) tb.ops.remove(seq_op) queue.remove(seq_op) @@ -577,6 +579,10 @@ def lower_tbs(self): for op in tb.ops: op.src = self.lower_chunk(op.src) op.dst = self.lower_chunk(op.dst) + srcs = sorted(op.srcs, key=lambda x: x[1]) + dsts = sorted(op.dsts, key=lambda x: x[1]) + op.srcs = [src[0] for src in srcs] + op.dsts = [dst[0] for dst in dsts] lowered_tbs[tbid] = tb gpus.append(Gpu(rank, list(lowered_tbs.values()))) return gpus From 3321c5f28f8b3aaea1248b187d87017463f39a17 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 26 Mar 2024 09:57:02 +0000 Subject: [PATCH 15/76] WIP --- msccl/language/ir.py | 10 ++++++---- msccl/language/rank_dag.py | 6 ++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 5d1b11a..7346b3f 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -578,6 +578,7 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty "connectedTo": [eles[1] for eles in channels] } gpu_instance["channels"].append(obj) + gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) for tb in gpu.threadblocks: if tb.id == -1: continue @@ -586,14 +587,15 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty tb_channel_dict = {} for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): obj = { - "srcBuffer": srcBuffer.name if hasattr(srcBuffer, 'name') else srcBuffer, - "dstBuffer": dstBuffer.name if hasattr(dstBuffer, 'name') else dstBuffer, + "srcBuffer": srcBuffer.value if hasattr(srcBuffer, 'name') else srcBuffer, + "dstBuffer": dstBuffer.value if hasattr(dstBuffer, 'name') else dstBuffer, "type": type.name, "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id], "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id], } tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj tb_channels.append(obj) + tb_channels = filter(lambda x: x["type"] != "none", tb_channels) for op in tb.ops: if op.tb == -1: continue @@ -650,14 +652,14 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty "dst": op.dst.rank if op.dst else None, "dstbuff": op.dst.buffer.value if op.dst.buffer else None, "dstoff": op.dst.index if op.dst else None, - "channel_type": op.channel_type.value, + "ctype": op.channel_type.value, "cnt": op.cnt(), } ops.append(instr) threadblock = { 'id': tb.id, 'ops': ops, - 'channels': list(map(lambda x: {"s": x["srcBuffer"], "d": x["dstBuffer"], "t": x["type"], "cid": x["chanIds"]}, tb_channels)) + 'channels': list(map(lambda x: {"src": x["srcBuffer"], "dst": x["dstBuffer"], "ctype": x["type"], "cid": x["chanIds"]}, tb_channels)) } gpu_instance['threadblocks'].append(threadblock) gpus.append(gpu_instance) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 49199c7..5d41c5e 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -144,7 +144,7 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, ch): srcindex = send_ref.index size = recv_ref.size prev_ops = [] - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size)), tb_step) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) # Sending part of reduce self._read(rank, srcbuffer, srcindex, size, op) # Reduce part of copy @@ -328,7 +328,7 @@ def _optimize_rrc_r_signal_wait(self): fused = False for next_op in op.next: if next_op.inst == Instruction.reduce and same_buf_dst(op, next_op) and same_chan_type(op, next_op): - op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size)), next_op.step) + op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) queue.remove(next_op) @@ -408,6 +408,8 @@ def _optimize_rrcs_rs(self): def _parallel_signal_wait(self): for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): + if tbid == -1: + continue queue = list(tb.ops) while len(queue) > 0: op = queue[0] From 7e4bd8b82a6c84ccbfd808feeccff0366a423b54 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 26 Mar 2024 11:18:50 +0000 Subject: [PATCH 16/76] WIP --- msccl/language/ir.py | 22 +++++++++++++++++++++- msccl/language/rank_dag.py | 6 +++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 7346b3f..ad710e2 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -632,7 +632,7 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty } elif op.inst == Instruction.read_reduce_copy_send: src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type) instr = { "name": op.inst.value, "i_cids": src_channel_ids, @@ -643,6 +643,26 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty "ctype": op.channel_type.value, "cnt": op.cnt(), } + elif op.inst == Instruction.reduce_send: + dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm) + instr = { + "name": op.inst.value, + "o_cids": dst_channel_ids, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "dstoff": op.dst.index if op.dst else None, + "srcs": list(map(lambda x: {"buff": x.buffer, "off": x.index}, op.srcs)), + "cnt": op.cnt(), + } + elif op.inst == Instruction.reduce: + instr = { + "name": op.inst.value, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "dstoff": op.dst.index if op.dst else None, + "srcs": list(map(lambda x: {"buff": x.buffer, "off": x.index}, op.srcs)), + "cnt": op.cnt(), + } else: instr = { "name": op.inst.value, diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 5d41c5e..b2469be 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -376,6 +376,8 @@ def _optimize_rrcs_rs(self): fused = False for next_op in op.next: if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op) and same_chan_type(op, next_op): + if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: + continue if op.inst == Instruction.read_reduce_copy: op.inst = Instruction.read_reduce_copy_send op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) @@ -389,7 +391,9 @@ def _optimize_rrcs_rs(self): if op.inst == Instruction.reduce or op.inst == Instruction.reduce_send: fused = False for next_op in op.next: - if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op): + if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm: + if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: + continue if op.inst == Instruction.reduce: op.inst = Instruction.reduce_send op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) From 7074c016c6728d0b5ad87e8f12ea8236e7ebb47a Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 26 Mar 2024 15:07:35 +0000 Subject: [PATCH 17/76] WIP --- .../mscclang/allreduce_a100_ring_mscclpp.py | 48 +++++++++++++++++++ msccl/language/rank_dag.py | 2 +- 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 examples/mscclang/allreduce_a100_ring_mscclpp.py diff --git a/examples/mscclang/allreduce_a100_ring_mscclpp.py b/examples/mscclang/allreduce_a100_ring_mscclpp.py new file mode 100644 index 0000000..adfd627 --- /dev/null +++ b/examples/mscclang/allreduce_a100_ring_mscclpp.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + +# Ring all reduce for A100s +def allreduce_ring(size, instances, protocol): + topology = fully_connected(size) + collective = AllReduce(size, size, True) + with MSCCLProgram(f"allreduce_ring", topology, collective, instances, + protocol=protocol, threadblock_policy=ThreadblockPolicy.manual): + # Reduce ring + for step in range(0, size-1): + for index in range(0, size): + rank = (index + step) % size + next_rank = (index + step + 1) % size + c = chunk(rank, Buffer.input, index) + c.signal(next_rank, Buffer.input, index, 0) + prev_rank = (index + step - 1) % size + c = chunk(rank, Buffer.input, (index+size-1)%size) + c.wait(prev_rank, Buffer.input, (index+size-1)%size, 0) + c.reduce_mscclpp(chunk(prev_rank, Buffer.input, (index+size-1)%size), recvtb=0) + + # Propagate ring + for step in range(-1, size-2): + for index in range(0, size): + rank = (index + step) % size + c = chunk(rank, Buffer.input, index) + next_rank = (index + step + 1) % size + c.put(next_rank, Buffer.input, index, sendtb=0) + c.signal(next_rank, Buffer.input, index, 0) + prev_rank = (index + step - 1) % size + c = chunk(rank, Buffer.input, (index+size-1)%size) + c.wait(prev_rank, Buffer.input, (index+size-1)%size, 0) + + Json() + # Check() + +parser = argparse.ArgumentParser() +parser.add_argument('num_gpus', type=int, help ='number of gpus') +parser.add_argument('instances', type=int, help='number of instances') +parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple', 'LL'], help ='protocol. Default: Simple') +args = parser.parse_args() + +allreduce_ring(args.num_gpus, args.instances, args.protocol) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index b2469be..d7fad06 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -260,7 +260,7 @@ def optimize(self): def complete_channels(self): send_op = [Instruction.put, Instruction.signal] - recv_op = [Instruction.wait] + recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): chans = set() From f31d9b4f46ce455848d402f037b92cb48d9c5828 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 27 Mar 2024 03:43:41 +0000 Subject: [PATCH 18/76] Now for deps --- msccl/language/ir.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index ad710e2..c440b84 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -231,12 +231,14 @@ def __repr__(self): # Instructions where src is on local GPU -_local_src_insts = {Instruction.send, Instruction.copy, Instruction.reduce, Instruction.put, Instruction.signal} +_local_src_insts = {Instruction.send, Instruction.copy, Instruction.reduce} # Instructions where dst is on local GPU _local_dst_insts = {Instruction.recv, Instruction.recv_copy_send, Instruction.recv_reduce_send, Instruction.recv_reduce_copy, Instruction.copy, Instruction.reduce, - Instruction.recv_reduce_copy_send, Instruction.wait} -_send_insts = {Instruction.put} # do we need this? + Instruction.recv_reduce_copy_send} + +_local_src_insts_mscclpp = {Instruction.put, Instruction.signal, Instruction.copy, Instruction.reduce, Instruction.reduce_send} +_local_dst_insts_mscclpp = {Instruction.get, Instruction.wait, Instruction.read_reduce_copy, Instruction.copy, Instruction.reduce, Instruction.read_reduce_copy_send, Instruction.reduce_send} def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): @@ -431,18 +433,24 @@ def ir_to_json(program: Program, dependence_nop=False): for gpu in program.gpus: for tb in gpu.threadblocks: for op in tb.ops: - if op.inst in _local_src_insts: + if op.inst in _local_src_insts_mscclpp: key = (gpu.rank, op.src.buffer) buffer_sizes[key] = max( buffer_sizes[key], op.src.index + op.src.size) - if op.inst in _send_insts: - key = (op.dst.rank, op.dst.buffer) - buffer_sizes[key] = max( - buffer_sizes[key], op.dst.index + op.dst.size) - if op.inst in _local_dst_insts: + for src in op.srcs: + key = (gpu.rank, src.buffer) + buffer_sizes[key] = max( + buffer_sizes[key], src.index + src.size) + if op.inst in _local_dst_insts_mscclpp: key = (gpu.rank, op.dst.buffer) buffer_sizes[key] = max( buffer_sizes[key], op.dst.index + op.dst.size) + # ignore remote buffers + if op.inst != Instruction.read_reduce_copy_send and op.inst != Instruction.reduce_send: + for dst in op.dsts: + key = (gpu.rank, dst.buffer) + buffer_sizes[key] = max( + buffer_sizes[key], dst.index + dst.size) for gpu in program.gpus: gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks) gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) From dc8d44ed46e523e2bc12b10265e735798d73c26d Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 27 Mar 2024 08:21:07 +0000 Subject: [PATCH 19/76] let make instance work --- .../allreduce_a100_allpairs_sm_mscclpp.py | 16 ++-- msccl/language/ir.py | 80 ++++++------------- 2 files changed, 32 insertions(+), 64 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index 5840ae5..e504f9b 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -17,18 +17,20 @@ def allreduce_allpairs(gpus, instances, protocol): for rank in range(size): for tb in range(size): index = rank * size - c = chunk(rank, Buffer.input, index + tb) + c = chunk(rank, Buffer.input, index+tb) # step1 make sure the data is ready for nghr in range(size): + peer_index = nghr * size if rank != nghr: - c.signal(nghr, Buffer.input, index + tb, sendtb=tb) + c_peer = chunk(rank, Buffer.input, peer_index+tb) + c_peer.signal(nghr, Buffer.input, peer_index+tb, sendtb=tb) for nghr in range(size): if rank != nghr: - c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + c.wait(nghr, Buffer.input, index+tb, recvtb=tb) # step2 reduce the chunks and send to peers for nghr in range(size): if rank != nghr: - c.reduce_mscclpp(chunk(nghr, Buffer.input, index + tb), recvtb=tb) + c.reduce_mscclpp(chunk(nghr, Buffer.input, index+tb), recvtb=tb) for nghr in range(size): if rank != nghr: c.put(nghr, Buffer.input, index, sendtb=tb) @@ -38,9 +40,9 @@ def allreduce_allpairs(gpus, instances, protocol): c.signal(nghr, Buffer.input, index, sendtb=tb) for nghr in range(size): if rank != nghr: - index = nghr * size - c = chunk(rank, Buffer.input, index + tb) - c.wait(nghr, Buffer.input, index, recvtb=tb) + peer_index = nghr * size + c_peer = chunk(rank, Buffer.input, peer_index+tb) + c_peer.wait(nghr, Buffer.input, peer_index, recvtb=tb) Json() diff --git a/msccl/language/ir.py b/msccl/language/ir.py index c440b84..bd38e0a 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -472,17 +472,13 @@ def ir_to_json(program: Program, dependence_nop=False): chan_dict[key] = sorted(value) gpu.channels = chan_dict - # Filter out dependencies within the same threadblock - op_tb_id = {} - for gpu in program.gpus: - for tb in gpu.threadblocks: - for op in tb.ops: - op_tb_id[op] = op.tb + # Remove the dependencies of wait after signal. They are actually depends on remote chunk for gpu in program.gpus: for tb in gpu.threadblocks: for op in tb.ops: - op.depends = list( - filter(lambda dep: op_tb_id[dep] != op.tb, op.depends)) + if op.inst == Instruction.wait: + op.depends = list(filter(lambda dep: dep.inst != Instruction.signal, op.depends)) + # Filter out redundant dependencies # e.g. if op1 and op2 depend on op, and op1 happends before op2 # then op2 does not need to explicitly depend on op @@ -494,61 +490,26 @@ def ir_to_json(program: Program, dependence_nop=False): filter(lambda dep: dep not in running_depends, op.depends)) running_depends = running_depends + op.depends - # Mark all ops that have a dependence on them - has_dependence = set() - for gpu in program.gpus: - for tb in gpu.threadblocks: - for op in tb.ops: - has_dependence.update(op.depends) - - if dependence_nop: + # Do some additional postprocessing of operations: + # - Expand operations with dependencies with no-ops + if program.protocol != "LL": # ignore the dependence_nop for LL protocol for gpu in program.gpus: for tb in gpu.threadblocks: - pre_ops = [] - after_ops = [] - first_re = None - first_dep = None - for i, op in enumerate(tb.ops): + new_ops = [] + for op in tb.ops: # Expand extra dependencies into nop operations - num_depends = len(op.depends) - if op.inst is Instruction.reduce: - if num_depends > 0: - for dep in op.depends: - if first_dep is None: - first_dep = dep - else: - pre_ops.append(Op(Instruction.nop, -1, None, None, [dep])) - op.depends = [] - if first_re is None: - first_re = op - - if first_re is not None: - after_ops.append(op) - else: - pre_ops.append(op) - if first_dep is not None: - first_re.depends = [first_dep] - tb.ops = pre_ops + after_ops + for i, dep in enumerate(op.depends): + new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) + #op_tb_id[new_ops[-1]] = op_tb_id[op] + new_ops.append(op) + tb.ops = new_ops - # Do some additional postprocessing of operations: - # - Expand operations with extra dependencies with no-ops - # - Mark the index of each operation taking any extra no-ops into account - op_idx = {} + # update step and tid for ops for gpu in program.gpus: for tb in gpu.threadblocks: - new_ops = [] - for op in tb.ops: - # Expand extra dependencies into nop operations - if len(op.depends) > 1: - extra_deps = op.depends[1:] - op.depends = op.depends[:1] - for i, dep in enumerate(extra_deps): - new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) - op_idx[new_ops[-1]] = len(new_ops) - 1 - #op_tb_id[new_ops[-1]] = op_tb_id[op] - new_ops.append(op) - op_idx[new_ops[-1]] = len(new_ops) - 1 - tb.ops = new_ops + for i, op in enumerate(tb.ops): + op.step = i + op.tb = tb.id # Need to calculate channel info for each GPU nchannels = 0 @@ -671,6 +632,11 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty "srcs": list(map(lambda x: {"buff": x.buffer, "off": x.index}, op.srcs)), "cnt": op.cnt(), } + elif op.inst == Instruction.nop: + instr = { + "name": op.inst.value, + "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)) + } else: instr = { "name": op.inst.value, From 085be4a155aa5d9fa1bc0d28becdaad832cc0849 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 28 Mar 2024 08:43:34 +0000 Subject: [PATCH 20/76] enable instance --- .../allreduce_a100_allpairs_sm_mscclpp_get.py | 6 +- msccl/language/__init__.py | 5 +- msccl/language/ir.py | 9 +++ msccl/language/rank_dag.py | 72 +++++++++++++++++++ 4 files changed, 88 insertions(+), 4 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py index d9442fd..aa7c39e 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -12,7 +12,7 @@ def allreduce_allpairs(gpus, instances, protocol): topology = fully_connected(size) collective = AllReduce(size, chunksperloop, True) with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, - interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): + interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True, instance_policy=InstancePolicy.dup): # Each rank sends the nth chunk to the nth rank into scratch space for rank in range(size): @@ -21,8 +21,10 @@ def allreduce_allpairs(gpus, instances, protocol): c = chunk(rank, Buffer.input, index + tb) # make sure the data is ready for nghr in range(size): + peer_index = nghr * size if rank != nghr: - c.signal(nghr, Buffer.input, index + tb, sendtb=tb) + c_peer = chunk(rank, Buffer.input, peer_index+tb) + c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) for nghr in range(size): if rank != nghr: c.wait(nghr, Buffer.input, index + tb, recvtb=tb) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index ffa4524..09fd55d 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -28,7 +28,7 @@ def _curr(): class MSCCLProgram: def __init__(self, name, topo, collective, instances, protocol='Simple', \ threadblock_policy=ThreadblockPolicy.auto, interleaved_replication=True, - instr_fusion=True, check_xml=True, dependence_nop=False): + instr_fusion=True, check_xml=True, dependence_nop=False, instance_policy=InstancePolicy.dup): self.name = name self.topo = topo self.collective = collective @@ -40,6 +40,7 @@ def __init__(self, name, topo, collective, instances, protocol='Simple', \ self.instr_fusion = instr_fusion self.check_xml = check_xml self.dependence_nop = dependence_nop + self.instance_policy = instance_policy assert protocol == 'Simple' or protocol == 'LL' or protocol == 'LL128', \ f'Given protocol: {protocol}. Must be either Simple, LL, LL128' self.run_opt = True # Runs optimization passes @@ -137,7 +138,7 @@ def lower_mscclpp(self): if self.instr_fusion: self.instr_dag.optimize_mscclpp(self.protocol) self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) + gpu_prgms = self.instr_dag.lower_pt2_mscclpp(self.instances, self.instance_policy) return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) def generate_xml(self): diff --git a/msccl/language/ir.py b/msccl/language/ir.py index bd38e0a..db18921 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -99,6 +99,15 @@ class ThreadblockPolicy(Enum): def __str__(self): return self.value +class InstancePolicy(Enum): + # this means pack multi instrances to deal with the same chunk and share the channels + packed = 'packed' + # this means each instance deal with the same chunk + dup = 'dup' + + def __str__(self): + return self.value + class Instruction(Enum): nop = 'nop' diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index d7fad06..e8abaf6 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -541,6 +541,9 @@ def lower_pt2(self, instances, interleaved): self.replicate(instances, interleaved) return self.lower_tbs() + def lower_pt2_mscclpp(self, instances, instance_pollicy): + self.replicate_mscclpp(instances, instance_pollicy) + return self.lower_tbs() def infer_dependencies(self): for slot, ops in self.operations.items(): @@ -664,3 +667,72 @@ def get_instance_ref(ref): dep_step = dep.step iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] + def replicate_mscclpp(self, instances, instance_policy): + # update op step + for rank, rank_tbs in enumerate(self.tbs): + for _, tb in rank_tbs.items(): + for id, op in enumerate(tb.ops): + op.step = id + + if instances == 1: + self.instanced_tbs = self.tbs + return + + self.instanced_tbs = [] + for _ in range(self.num_ranks): + self.instanced_tbs.append({}) + + def is_scratch(buffer): + return buffer != Buffer.input and buffer != Buffer.output + + def get_new_index(rank, buffer, index, size, i): + # Scratch buffers always use batched + if is_scratch(buffer): + buf_instance_len = self.buffers[rank][buffer].instance_size() + return buf_instance_len * i + index + return len(self.buffers[rank][buffer]) * i + index + + def get_instance_ref(ref): + iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i) + iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) + return iref + + if instance_policy == InstancePolicy.dup: + for i in range(instances): + # Generate all the threadblocks and ops + for rank, rank_tbs in enumerate(self.tbs): + # rank_channels = self.num_channels[rank] + for tbid, tb in rank_tbs.items(): + itbid = tbid * instances + i + itb = Threadblock(id=itbid) + itb.ops = [None] * len(tb.ops) + for s, op in enumerate(tb.ops): + isrc = get_instance_ref(op.src) + idst = get_instance_ref(op.dst) + idepends = [] + # Note: We don't need the fill out the rest of the metadata since replication is the last optimization + iop = Op(op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type) + itb.ops[s] = iop + for src, step in op.srcs: + isrc = get_instance_ref(src) + iop.srcs.append((isrc, step)) + for dst, step in op.dsts: + idst = get_instance_ref(dst) + iop.dsts.append((idst, step)) + for chan in tb.channels: + itb.channels.append(chan) + self.instanced_tbs[op.rank][itbid] = itb + + # Redo dependency analysis + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + for i in range(instances): + itbid = tbid * instances + i + itb = self.instanced_tbs[rank][itbid] + for op, iop in zip(tb.ops, itb.ops): + iop.depends = [None] * len(op.depends) + for s, dep in enumerate(op.depends): + dep_tbid = dep.tb + dep_itbid = dep_tbid * instances + i + dep_step = dep.step + iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] From e613558beaf54ddc08fda2a43af1f040462c2fde Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 28 Mar 2024 09:05:33 +0000 Subject: [PATCH 21/76] fix --- msccl/language/ir.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index db18921..97ea10a 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -646,6 +646,28 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty "name": op.inst.value, "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)) } + elif op.inst == Instruction.put: + cids = get_channel_ids([op.dst], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + instr = { + "name": op.inst.value, + "o_cids": cids, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "srcoff": op.src.index if op.src else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "ctype": op.channel_type.value, + "cnt": op.cnt(), + } + elif op.inst == Instruction.get: + cids = get_channel_ids([op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + instr = { + "name": op.inst.value, + "i_cids": cids, + "srcbuff": op.src.buffer.value if op.src.buffer else None, + "dstbuff": op.dst.buffer.value if op.dst.buffer else None, + "dstoff": op.dst.index if op.dst else None, + "ctype": op.channel_type.value, + "cnt": op.cnt(), + } else: instr = { "name": op.inst.value, From 79f450ac51ede9cb9ea2260c286392bf60f35b11 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 28 Mar 2024 09:09:54 +0000 Subject: [PATCH 22/76] update ignore --- .gitignore | 3 +++ .vscode/launch.json | 17 ----------------- 2 files changed, 3 insertions(+), 17 deletions(-) delete mode 100644 .vscode/launch.json diff --git a/.gitignore b/.gitignore index 21f3b4c..4beca19 100755 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,6 @@ dmypy.json # Pyre type checker .pyre/ + +# vscode +.vscode/ diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 1bdfdd3..0000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "name": "Python Debugger: Current File with Arguments", - "type": "debugpy", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "args": "4 1 --protocol Simple", - "justMyCode": false - } - ] -} From c4a10ddd89d243e291f1772d878364f80f2a3ec9 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 29 Mar 2024 10:25:23 +0000 Subject: [PATCH 23/76] bug fix --- msccl/language/ir.py | 4 ++-- msccl/language/rank_dag.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 97ea10a..dfff602 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -629,7 +629,7 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty "srcbuff": op.src.buffer.value if op.src.buffer else None, "dstbuff": op.dst.buffer.value if op.dst.buffer else None, "dstoff": op.dst.index if op.dst else None, - "srcs": list(map(lambda x: {"buff": x.buffer, "off": x.index}, op.srcs)), + "srcs": list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)), "cnt": op.cnt(), } elif op.inst == Instruction.reduce: @@ -638,7 +638,7 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty "srcbuff": op.src.buffer.value if op.src.buffer else None, "dstbuff": op.dst.buffer.value if op.dst.buffer else None, "dstoff": op.dst.index if op.dst else None, - "srcs": list(map(lambda x: {"buff": x.buffer, "off": x.index}, op.srcs)), + "srcs": list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)), "cnt": op.cnt(), } elif op.inst == Instruction.nop: diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index e8abaf6..d388c87 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -265,11 +265,13 @@ def complete_channels(self): for tbid, tb in rank_tbs.items(): chans = set() for op in tb.ops: + src_buffer = Buffer.scratch if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output else op.src.buffer + dst_buffer = Buffer.scratch if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output else op.dst.buffer if op.inst in send_op: - chan = Channel(op.src.buffer, op.dst.buffer, op.channel_type, op.dst.rank) + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) chans.add(chan) elif op.inst in recv_op: - chan = Channel(op.dst.buffer, op.src.buffer, op.channel_type, op.src.rank) + chan = Channel(dst_buffer, src_buffer, op.channel_type, op.src.rank) chans.add(chan) tb.channels = list(chans) @@ -590,8 +592,8 @@ def lower_tbs(self): op.dst = self.lower_chunk(op.dst) srcs = sorted(op.srcs, key=lambda x: x[1]) dsts = sorted(op.dsts, key=lambda x: x[1]) - op.srcs = [src[0] for src in srcs] - op.dsts = [dst[0] for dst in dsts] + op.srcs = [self.lower_chunk(src[0]) for src in srcs] + op.dsts = [self.lower_chunk(dst[0]) for dst in dsts] lowered_tbs[tbid] = tb gpus.append(Gpu(rank, list(lowered_tbs.values()))) return gpus From ec4a112aa29d2f8521a20f8939b9146e0b151d17 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 2 Apr 2024 07:12:13 +0000 Subject: [PATCH 24/76] update --- msccl/language/ir.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index dfff602..dcfb7ff 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -542,32 +542,32 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty for id, gpu in enumerate(program.gpus): gpu_instance = { 'id': id, - 'input_chunks': gpu.input_chunks, - 'output_chunks': gpu.output_chunks, - 'scratch_chunks': gpu.scratch_chunks, + 'inputChunks': gpu.input_chunks, + 'outputChunks': gpu.output_chunks, + 'scratchChunks': gpu.scratch_chunks, 'threadblocks': [], "channels": [] } for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): obj = { - "srcBuffer": srcBuffer.name if hasattr(srcBuffer, 'name') else srcBuffer, - "dstBuffer": dstBuffer.name if hasattr(dstBuffer, 'name') else dstBuffer, - "type": type.name, + "srcbuff": srcBuffer.value if hasattr(srcBuffer, 'value') else srcBuffer, + "dstbuff": dstBuffer.value if hasattr(dstBuffer, 'value') else dstBuffer, + "type": type.value, "connectedTo": [eles[1] for eles in channels] } gpu_instance["channels"].append(obj) gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) for tb in gpu.threadblocks: - if tb.id == -1: + if tb.id < 0: continue ops = [] tb_channels = [] tb_channel_dict = {} for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): obj = { - "srcBuffer": srcBuffer.value if hasattr(srcBuffer, 'name') else srcBuffer, - "dstBuffer": dstBuffer.value if hasattr(dstBuffer, 'name') else dstBuffer, - "type": type.name, + "srcbuff": srcBuffer.value if hasattr(srcBuffer, 'value') else srcBuffer, + "dstbuff": dstBuffer.value if hasattr(dstBuffer, 'value') else dstBuffer, + "type": type.value, "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id], "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id], } @@ -684,7 +684,7 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty threadblock = { 'id': tb.id, 'ops': ops, - 'channels': list(map(lambda x: {"src": x["srcBuffer"], "dst": x["dstBuffer"], "ctype": x["type"], "cid": x["chanIds"]}, tb_channels)) + 'channels': list(map(lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cid": x["chanIds"]}, tb_channels)) } gpu_instance['threadblocks'].append(threadblock) gpus.append(gpu_instance) From 82de232be9cf3d974488b98cc52289c0799ff84b Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 2 Apr 2024 13:05:25 +0000 Subject: [PATCH 25/76] update --- msccl/language/__init__.py | 2 +- msccl/language/ir.py | 131 +++++++++++++------------------------ msccl/language/rank_dag.py | 18 ++--- 3 files changed, 57 insertions(+), 94 deletions(-) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 09fd55d..1b2a26c 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -373,7 +373,7 @@ def reduce_mscclpp(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=Chan self.prog.apply_reduce(src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size) if src != dst: - self.prog.instr_dag.add_read_reduce_copy(dst, other_chunkref, self, recvtb, channel_type) + self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type) else: self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ChannelType.none) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index dcfb7ff..5843260 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -117,8 +117,8 @@ class Instruction(Enum): recv_reduce_send = 'rrs' recv_reduce_copy = 'rrc' recv_reduce_copy_send = 'rrcs' - read_reduce_copy = "rrc" - read_reduce_copy_send = "rrcs" + read_reduce = "rr" + read_reduce_send = "rrs" reduce_send = 'rs' copy = 'cpy' reduce = 're' @@ -247,7 +247,7 @@ def __repr__(self): Instruction.recv_reduce_copy_send} _local_src_insts_mscclpp = {Instruction.put, Instruction.signal, Instruction.copy, Instruction.reduce, Instruction.reduce_send} -_local_dst_insts_mscclpp = {Instruction.get, Instruction.wait, Instruction.read_reduce_copy, Instruction.copy, Instruction.reduce, Instruction.read_reduce_copy_send, Instruction.reduce_send} +_local_dst_insts_mscclpp = {Instruction.get, Instruction.wait, Instruction.read_reduce, Instruction.copy, Instruction.reduce, Instruction.read_reduce_send, Instruction.reduce_send} def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): @@ -455,7 +455,7 @@ def ir_to_json(program: Program, dependence_nop=False): buffer_sizes[key] = max( buffer_sizes[key], op.dst.index + op.dst.size) # ignore remote buffers - if op.inst != Instruction.read_reduce_copy_send and op.inst != Instruction.reduce_send: + if op.inst != Instruction.read_reduce_send and op.inst != Instruction.reduce_send: for dst in op.dsts: key = (gpu.rank, dst.buffer) buffer_sizes[key] = max( @@ -539,6 +539,9 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty channel_ids.extend([{"id": id, "off": c.index} for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) if ele == c.rank]) return channel_ids + def remove_empty_fields(d): + return {k: v for k, v in d.items() if v not in [None, "", [], {}]} + for id, gpu in enumerate(program.gpus): gpu_instance = { 'id': id, @@ -575,112 +578,72 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty tb_channels.append(obj) tb_channels = filter(lambda x: x["type"] != "none", tb_channels) for op in tb.ops: + o_buff = None + i_buff = None + dst_channel_ids = [] + src_channel_ids = [] + srcs = [] + src = None + dst = None if op.tb == -1: continue if op.inst == Instruction.signal: # get dst channel ids dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - instr = { - "name": op.inst.name, - "o_cids": dst_channel_ids, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "ctype": op.channel_type.value, - } + o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} elif op.inst == Instruction.wait: # get src channel ids src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - instr = { - "name": op.inst.name, - "i_cids": src_channel_ids, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "ctype": op.channel_type.value, - } - elif op.inst == Instruction.read_reduce_copy: + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + elif op.inst == Instruction.read_reduce: src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - instr = { - "name": op.inst.value, - "i_cids": src_channel_ids, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "dstoff": op.dst.index if op.dst else None, - "ctype": op.channel_type.value, - "cnt": op.cnt(), - } - elif op.inst == Instruction.read_reduce_copy_send: + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + dst = op.dst + elif op.inst == Instruction.read_reduce_send: src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type) - instr = { - "name": op.inst.value, - "i_cids": src_channel_ids, - "o_cids": dst_channel_ids, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "dstoff": op.dst.index if op.dst else None, - "ctype": op.channel_type.value, - "cnt": op.cnt(), - } + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} + dst = op.dst elif op.inst == Instruction.reduce_send: dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm) - instr = { - "name": op.inst.value, - "o_cids": dst_channel_ids, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "dstoff": op.dst.index if op.dst else None, - "srcs": list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)), - "cnt": op.cnt(), - } + o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} + srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) + dst = op.dst elif op.inst == Instruction.reduce: - instr = { - "name": op.inst.value, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "dstoff": op.dst.index if op.dst else None, - "srcs": list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)), - "cnt": op.cnt(), - } + srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) + dst = op.dst elif op.inst == Instruction.nop: instr = { "name": op.inst.value, "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)) } elif op.inst == Instruction.put: - cids = get_channel_ids([op.dst], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - instr = { - "name": op.inst.value, - "o_cids": cids, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "srcoff": op.src.index if op.src else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "ctype": op.channel_type.value, - "cnt": op.cnt(), - } + dst_channel_ids = get_channel_ids([op.dst], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + src = op.src elif op.inst == Instruction.get: - cids = get_channel_ids([op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + src_channel_ids = get_channel_ids([op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + dst = op.dst + if op.inst != Instruction.nop: instr = { "name": op.inst.value, - "i_cids": cids, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "dstoff": op.dst.index if op.dst else None, - "ctype": op.channel_type.value, - "cnt": op.cnt(), - } - else: - instr = { - "name": op.inst.value, - "src": op.src.rank if op.src else None, - "srcbuff": op.src.buffer.value if op.src.buffer else None, - "srcoff": op.src.index if op.src else None, - "dst": op.dst.rank if op.dst else None, - "dstbuff": op.dst.buffer.value if op.dst.buffer else None, - "dstoff": op.dst.index if op.dst else None, + "i_buff": i_buff, + "i_cids": src_channel_ids, + "o_buff": o_buff, + "o_cids": dst_channel_ids, + "src": src.rank if src else None, + "srcs": srcs if srcs else None, + "srcbuff": src.buffer.value if src and src.buffer else None, + "srcoff": src.index if src else None, + "dst": dst.rank if dst else None, + "dstbuff": dst.buffer.value if dst and dst.buffer else None, + "dstoff": dst.index if dst else None, "ctype": op.channel_type.value, "cnt": op.cnt(), } - ops.append(instr) + ops.append(remove_empty_fields(instr)) threadblock = { 'id': tb.id, 'ops': ops, diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index d388c87..c672382 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -223,9 +223,9 @@ def add_recv_reduce_copy(self, rank, send_ref, recv_ref, tb, ch, send_op): op.send_match = send_op return op - def add_read_reduce_copy(self, rank, send_ref, recv_ref, tb, ch_type): + def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.read_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) + op = Op(Instruction.read_reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) buffer = recv_ref.buffer index = recv_ref.index size = recv_ref.size @@ -260,7 +260,7 @@ def optimize(self): def complete_channels(self): send_op = [Instruction.put, Instruction.signal] - recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] + recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce] for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): chans = set() @@ -293,7 +293,7 @@ def _optimize_redandant_signal_wait(self, protocol): break if fused: continue - elif op.inst == Instruction.reduce or op.inst == Instruction.read_reduce_copy or op.inst == Instruction.copy: + elif op.inst == Instruction.reduce or op.inst == Instruction.read_reduce or op.inst == Instruction.copy: fused = False for prev_op in op.prev: if prev_op.inst == Instruction.wait: @@ -314,10 +314,10 @@ def _optimize_rrc_r_signal_wait(self): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.read_reduce_copy: + if op.inst == Instruction.read_reduce: fused = False for next_op in op.next: - if next_op.inst == Instruction.read_reduce_copy and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op): + if next_op.inst == Instruction.read_reduce and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op): op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) @@ -374,14 +374,14 @@ def _optimize_rrcs_rs(self): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.read_reduce_copy or op.inst == Instruction.read_reduce_copy_send: + if op.inst == Instruction.read_reduce or op.inst == Instruction.read_reduce_send: fused = False for next_op in op.next: if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op) and same_chan_type(op, next_op): if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue - if op.inst == Instruction.read_reduce_copy: - op.inst = Instruction.read_reduce_copy_send + if op.inst == Instruction.read_reduce: + op.inst = Instruction.read_reduce_send op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) From 171e894771b6043a9d05bc6fc041b85f52524614 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 5 Apr 2024 11:37:10 +0000 Subject: [PATCH 26/76] fix --- .../allreduce_a100_allpairs_sm_mscclpp.py | 9 +++++---- msccl/language/ir.py | 15 ++++++++------- msccl/language/rank_dag.py | 16 ++++++++-------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index e504f9b..ca4670f 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -22,6 +22,7 @@ def allreduce_allpairs(gpus, instances, protocol): for nghr in range(size): peer_index = nghr * size if rank != nghr: + # signal peer the buffer is ready c_peer = chunk(rank, Buffer.input, peer_index+tb) c_peer.signal(nghr, Buffer.input, peer_index+tb, sendtb=tb) for nghr in range(size): @@ -33,16 +34,16 @@ def allreduce_allpairs(gpus, instances, protocol): c.reduce_mscclpp(chunk(nghr, Buffer.input, index+tb), recvtb=tb) for nghr in range(size): if rank != nghr: - c.put(nghr, Buffer.input, index, sendtb=tb) - # step3 signal the peers to receive the chunks + c.put(nghr, Buffer.input, index+tb, sendtb=tb) + # step3 signal the peers buffer is ready for nghr in range(size): if rank != nghr: - c.signal(nghr, Buffer.input, index, sendtb=tb) + c.signal(nghr, Buffer.input, index+tb, sendtb=tb) for nghr in range(size): if rank != nghr: peer_index = nghr * size c_peer = chunk(rank, Buffer.input, peer_index+tb) - c_peer.wait(nghr, Buffer.input, peer_index, recvtb=tb) + c_peer.wait(nghr, Buffer.input, peer_index+tb, recvtb=tb) Json() diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 5843260..95b413f 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -117,8 +117,8 @@ class Instruction(Enum): recv_reduce_send = 'rrs' recv_reduce_copy = 'rrc' recv_reduce_copy_send = 'rrcs' - read_reduce = "rr" - read_reduce_send = "rrs" + read_reduce_copy = "rrc" + read_reduce_copy_send = "rrcs" reduce_send = 'rs' copy = 'cpy' reduce = 're' @@ -247,7 +247,7 @@ def __repr__(self): Instruction.recv_reduce_copy_send} _local_src_insts_mscclpp = {Instruction.put, Instruction.signal, Instruction.copy, Instruction.reduce, Instruction.reduce_send} -_local_dst_insts_mscclpp = {Instruction.get, Instruction.wait, Instruction.read_reduce, Instruction.copy, Instruction.reduce, Instruction.read_reduce_send, Instruction.reduce_send} +_local_dst_insts_mscclpp = {Instruction.get, Instruction.wait, Instruction.read_reduce_copy, Instruction.copy, Instruction.reduce, Instruction.read_reduce_copy_send, Instruction.reduce_send} def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): @@ -455,7 +455,7 @@ def ir_to_json(program: Program, dependence_nop=False): buffer_sizes[key] = max( buffer_sizes[key], op.dst.index + op.dst.size) # ignore remote buffers - if op.inst != Instruction.read_reduce_send and op.inst != Instruction.reduce_send: + if op.inst != Instruction.read_reduce_copy_send and op.inst != Instruction.reduce_send: for dst in op.dsts: key = (gpu.rank, dst.buffer) buffer_sizes[key] = max( @@ -595,16 +595,17 @@ def remove_empty_fields(d): # get src channel ids src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - elif op.inst == Instruction.read_reduce: + elif op.inst == Instruction.read_reduce_copy: src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} dst = op.dst - elif op.inst == Instruction.read_reduce_send: + elif op.inst == Instruction.read_reduce_copy_send: src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} dst = op.dst + src = op.dst # the src is the same as dst elif op.inst == Instruction.reduce_send: dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm) o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} @@ -647,7 +648,7 @@ def remove_empty_fields(d): threadblock = { 'id': tb.id, 'ops': ops, - 'channels': list(map(lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cid": x["chanIds"]}, tb_channels)) + 'channels': list(map(lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cids": x["chanIds"]}, tb_channels)) } gpu_instance['threadblocks'].append(threadblock) gpus.append(gpu_instance) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index c672382..e1df383 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -225,7 +225,7 @@ def add_recv_reduce_copy(self, rank, send_ref, recv_ref, tb, ch, send_op): def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.read_reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) + op = Op(Instruction.read_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) buffer = recv_ref.buffer index = recv_ref.index size = recv_ref.size @@ -260,7 +260,7 @@ def optimize(self): def complete_channels(self): send_op = [Instruction.put, Instruction.signal] - recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce] + recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): chans = set() @@ -293,7 +293,7 @@ def _optimize_redandant_signal_wait(self, protocol): break if fused: continue - elif op.inst == Instruction.reduce or op.inst == Instruction.read_reduce or op.inst == Instruction.copy: + elif op.inst == Instruction.reduce or op.inst == Instruction.read_reduce_copy or op.inst == Instruction.copy: fused = False for prev_op in op.prev: if prev_op.inst == Instruction.wait: @@ -314,10 +314,10 @@ def _optimize_rrc_r_signal_wait(self): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.read_reduce: + if op.inst == Instruction.read_reduce_copy: fused = False for next_op in op.next: - if next_op.inst == Instruction.read_reduce and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op): + if next_op.inst == Instruction.read_reduce_copy and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op): op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) @@ -374,14 +374,14 @@ def _optimize_rrcs_rs(self): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.read_reduce or op.inst == Instruction.read_reduce_send: + if op.inst == Instruction.read_reduce_copy or op.inst == Instruction.read_reduce_copy_send: fused = False for next_op in op.next: if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op) and same_chan_type(op, next_op): if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue - if op.inst == Instruction.read_reduce: - op.inst = Instruction.read_reduce_send + if op.inst == Instruction.read_reduce_copy: + op.inst = Instruction.read_reduce_copy_send op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) From 99ff31cf1ab00590685610fd09f51f1c6e023f60 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sun, 7 Apr 2024 07:39:44 +0000 Subject: [PATCH 27/76] WIP --- .../allreduce_a100_allpairs_packet_mscclpp.py | 13 +-- msccl/language/__init__.py | 64 +++++++++++++++ msccl/language/ir.py | 17 +++- msccl/language/rank_dag.py | 82 +++++++++++++++++-- 4 files changed, 158 insertions(+), 18 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index bc12011..f440545 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -18,9 +18,10 @@ def allreduce_allpairs(gpus, instances): for r1 in range(size): for r2 in range(size): if r1 != r2: - index = r2 * size - c = chunk(r1, Buffer.input, index, size=size) - c.put(r2, 'scratch', index=r1*size, sendtb=r2) + for tb in range(size): + index = r2 * size + tb + c = chunk(r1, Buffer.input, index) + c.put_packet(r2, 'scratch', index=r1*size+tb, sendtb=tb) # Each rank performs a local reduction on the nth chunk # Utilize 8 threadblocks for this reduction for better parallelism @@ -29,10 +30,10 @@ def allreduce_allpairs(gpus, instances): c = chunk(r, Buffer.input, r*size + index) for peer in range(size): if peer != r: - c.reduce_mscclpp(chunk(r, 'scratch', peer*size+index), sendtb=index) + c.reduce_packet(chunk(r, 'scratch', peer*size+index), sendtb=index) for peer in range(size): if peer != r: - c.put(peer, 'scratch', (size*size)+r*size+index, sendtb=index) + c.put_packet(peer, 'scratch', (size*size)+r*size+index, sendtb=index) # Each rank get final result from scratch space for r in range(size): @@ -40,7 +41,7 @@ def allreduce_allpairs(gpus, instances): for peer in range(size): if peer != r: c = chunk(r, 'scratch', size*size+peer*size+index) - c.copy(r, Buffer.input, peer*size+index, sendtb=index) + c.copy_packet(r, Buffer.input, peer*size+index, sendtb=index) Json() # Check() diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 1b2a26c..f7f5779 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -241,6 +241,31 @@ def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) + def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): + self.prog.check_buffer_exists(dst, buffer) + sender = self.rank + receiver = dst + assert sender != receiver, 'Cannot put to the same rank' + + # If index is not specified assume it is going to the same place in the next gpu + if index == -1 and buffer == None: + index = self.index + buffer = self.buffer + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + index = self.prog.buffers[dst][buffer].instance_size() + + # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) + buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) + + # Direct put + assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}' + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + self.prog.instr_dag.add_put_packet(sender, self, dst_chunkref, sendtb, channel_type) + self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) + self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) + def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): self.prog.check_buffer_exists(src, buffer) sender = src @@ -340,6 +365,35 @@ def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): return dst_chunkref + def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): + self.prog.check_buffer_exists(dst, buffer) + + # If index is not specified assume it is going to the same place in the next gpu + if index == -1 and buffer == None: + index = self.index + buffer = self.buffer + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + index = self.prog.buffers[dst][buffer].instance_size() + + # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) + buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) + if dst_chunkref == self: + return + + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + + # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) + sender = self.rank + receiver = dst + assert sender == receiver, 'Packet copy only supports intra-rank communication' + self.prog.instr_dag.add_copy_packet(sender, self, dst_chunkref, sendtb) + + return dst_chunkref + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): # Receive reduce copy @@ -379,6 +433,16 @@ def reduce_mscclpp(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=Chan return self + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce_packet(self, other_chunkref, sendtb=-1): + # Receive reduce copy + dst = self.rank + src = other_chunkref.rank + assert dst == src, 'Packet reduce only supports intra-rank communication' + self.prog.apply_reduce(src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size) + self.prog.instr_dag.add_reduce_packet(src, other_chunkref, self, sendtb) + return self + def get_origin_index(self, index=0): return self._get_chunk(index + self.index).origin_index diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 95b413f..2a8a27e 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -120,11 +120,15 @@ class Instruction(Enum): read_reduce_copy = "rrc" read_reduce_copy_send = "rrcs" reduce_send = 'rs' + reduce_send_packet = 'rspkt' copy = 'cpy' + copy_packet = 'cpkt' reduce = 're' + reduce_packet = 'rpkt' delete = 'd' start = 'st' put = 'put' + put_packet = 'ppkt' get = 'get' wait = 'wait' signal = 'signal' @@ -501,7 +505,7 @@ def ir_to_json(program: Program, dependence_nop=False): # Do some additional postprocessing of operations: # - Expand operations with dependencies with no-ops - if program.protocol != "LL": # ignore the dependence_nop for LL protocol + if program.protocol != "LL": # (TODO(binyli): fix it) ignore the dependence_nop for LL protocol for gpu in program.gpus: for tb in gpu.threadblocks: new_ops = [] @@ -599,18 +603,20 @@ def remove_empty_fields(d): src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} dst = op.dst + src = op.dst # TODO(binyli): fix this elif op.inst == Instruction.read_reduce_copy_send: src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} dst = op.dst - src = op.dst # the src is the same as dst - elif op.inst == Instruction.reduce_send: + src = op.dst # TODO(binyli): fix this + elif op.inst == Instruction.reduce_send or op.inst == Instruction.reduce_send_packet: dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm) o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) dst = op.dst + src = op.dst # TODO(binyli): fix this elif op.inst == Instruction.reduce: srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) dst = op.dst @@ -619,7 +625,7 @@ def remove_empty_fields(d): "name": op.inst.value, "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)) } - elif op.inst == Instruction.put: + elif op.inst == Instruction.put or op.inst == Instruction.put_packet: dst_channel_ids = get_channel_ids([op.dst], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} src = op.src @@ -627,6 +633,9 @@ def remove_empty_fields(d): src_channel_ids = get_channel_ids([op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} dst = op.dst + elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet: + src = op.src + dst = op.dst if op.inst != Instruction.nop: instr = { "name": op.inst.value, diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index e1df383..5ace4a1 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -134,6 +134,20 @@ def add_copy(self, rank, send_ref, recv_ref, tb, ch): self._write(rank, dstbuffer, dstindex, size, op) return op + def add_copy_packet(self, rank, send_ref, recv_ref, tb): + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + dstbuffer = recv_ref.buffer + dstindex = recv_ref.index + srcbuffer = send_ref.buffer + srcindex = send_ref.index + size = recv_ref.size + # Sending part of copy [Read] + self._read(rank, srcbuffer, srcindex, size, op) + # Receiving part of copy [Write] + self._write(rank, dstbuffer, dstindex, size, op) + return op + # InstructionDAG - adds a redduce node def add_reduce(self, rank, send_ref, recv_ref, tb, ch): tb_step = self._get_tb_step(rank, tb) @@ -151,6 +165,22 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, ch): self._write(rank, dstbuffer, dstindex, size, op, read=True) return op + # InstructionDAG - adds a redduce packet node + def add_reduce_packet(self, rank, send_ref, recv_ref, tb): + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + dstbuffer = recv_ref.buffer + dstindex = recv_ref.index + srcbuffer = send_ref.buffer + srcindex = send_ref.index + size = recv_ref.size + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + # Sending part of reduce + self._read(rank, srcbuffer, srcindex, size, op) + # Reduce part of copy + self._write(rank, dstbuffer, dstindex, size, op, read=True) + return op + # InstructionDAG - adds a send node def add_send(self, rank, send_ref, recv_ref, tb, ch): op = Op(Instruction.send, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) @@ -170,6 +200,15 @@ def add_put(self, rank, send_ref, recv_ref, tb, ch_type): self._read(rank, buffer, index, size, op) return op + def add_put_packet(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op(Instruction.put_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + return op + def add_get(self, rank, send_ref, recv_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) op = Op(Instruction.get, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) @@ -259,7 +298,7 @@ def optimize(self): self._optimize_rcs() def complete_channels(self): - send_op = [Instruction.put, Instruction.signal] + send_op = [Instruction.put, Instruction.signal, Instruction.put_packet] recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): @@ -275,16 +314,14 @@ def complete_channels(self): chans.add(chan) tb.channels = list(chans) - def _optimize_redandant_signal_wait(self, protocol): - if protocol != 'LL': - return - # For LL algorithm, we can remove signal/wait + def _optimize_redandant_signal_wait(self): + # For packet ops, we can remove signal/wait for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.put: + if op.inst == Instruction.put_packet: fused = False for next_op in op.next: if next_op.inst == Instruction.signal: @@ -293,7 +330,7 @@ def _optimize_redandant_signal_wait(self, protocol): break if fused: continue - elif op.inst == Instruction.reduce or op.inst == Instruction.read_reduce_copy or op.inst == Instruction.copy: + elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet: fused = False for prev_op in op.prev: if prev_op.inst == Instruction.wait: @@ -308,6 +345,7 @@ def _optimize_redandant_signal_wait(self, protocol): # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di]) # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di) + # reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di) def _optimize_rrc_r_signal_wait(self): for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): @@ -338,6 +376,18 @@ def _optimize_rrc_r_signal_wait(self): break if fused: continue + elif op.inst == Instruction.reduce_packet: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.reduce_packet and same_buf_dst(op, next_op) and same_chan_type(op, next_op): + op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue elif op.inst == Instruction.signal: fused = False for next_op in op.next: @@ -406,6 +456,22 @@ def _optimize_rrcs_rs(self): break if fused: continue + if op.inst == Instruction.reduce_packet or op.inst == Instruction.reduce_send_packet: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.put_packet and same_count(op, next_op) and buf_dst_src_match(op, next_op): + if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: + continue + if op.inst == Instruction.reduce_packet: + op.inst = Instruction.reduce_send_packet + op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue queue = queue[1:] # For signal/wait ops, if they are independent of other operations and no other operations in between, @@ -448,7 +514,7 @@ def _parallel_signal_wait(self): queue = queue[1:] def optimize_mscclpp(self, protocol): - self._optimize_redandant_signal_wait(protocol) + self._optimize_redandant_signal_wait() self._optimize_rrc_r_signal_wait() self._optimize_rrcs_rs() From 93683b5449719f0f72c2fb37a4bd4f017dac0911 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 03:58:12 +0000 Subject: [PATCH 28/76] WIP --- msccl/language/rank_dag.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 5ace4a1..329fdc7 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -459,11 +459,12 @@ def _optimize_rrcs_rs(self): if op.inst == Instruction.reduce_packet or op.inst == Instruction.reduce_send_packet: fused = False for next_op in op.next: - if next_op.inst == Instruction.put_packet and same_count(op, next_op) and buf_dst_src_match(op, next_op): + if next_op.inst == Instruction.put_packet and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm: if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: continue if op.inst == Instruction.reduce_packet: op.inst = Instruction.reduce_send_packet + op.channel_type = ChannelType.sm op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) From 10b648ca0bdd2a2984a4b4e170871199edd80612 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 08:38:15 +0000 Subject: [PATCH 29/76] WIP --- .../allreduce_a100_allpairs_packet_mscclpp.py | 28 +- .../allreduce_a100_allpairs_sm_mscclpp.py | 30 +- msccl/language/__init__.py | 187 +---------- msccl/language/mscclpp.py | 316 ++++++++++++++++++ pyproject.toml | 4 + 5 files changed, 360 insertions(+), 205 deletions(-) create mode 100644 msccl/language/mscclpp.py create mode 100644 pyproject.toml diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index f440545..1e5ae90 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -6,13 +6,20 @@ from msccl.topologies import * from msccl.language.collectives import AllReduce + def allreduce_allpairs(gpus, instances): size = gpus chunksperloop = gpus * gpus topology = fully_connected(size) collective = AllReduce(size, chunksperloop, True) - with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol="LL", - interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): + with MSCCLPPProgram( + "allreduce_pairs", + topology, + collective, + instances, + protocol="LL", + dependence_nop=True, + ): # Each rank sends the nth chunk to the nth rank into scratch space for r1 in range(size): @@ -21,34 +28,35 @@ def allreduce_allpairs(gpus, instances): for tb in range(size): index = r2 * size + tb c = chunk(r1, Buffer.input, index) - c.put_packet(r2, 'scratch', index=r1*size+tb, sendtb=tb) + c.put_packet(r2, "scratch", index=r1 * size + tb, sendtb=tb) # Each rank performs a local reduction on the nth chunk # Utilize 8 threadblocks for this reduction for better parallelism for r in range(size): for index in range(size): - c = chunk(r, Buffer.input, r*size + index) + c = chunk(r, Buffer.input, r * size + index) for peer in range(size): if peer != r: - c.reduce_packet(chunk(r, 'scratch', peer*size+index), sendtb=index) + c.reduce_packet(chunk(r, "scratch", peer * size + index), sendtb=index) for peer in range(size): if peer != r: - c.put_packet(peer, 'scratch', (size*size)+r*size+index, sendtb=index) + c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index) # Each rank get final result from scratch space for r in range(size): for index in range(size): for peer in range(size): if peer != r: - c = chunk(r, 'scratch', size*size+peer*size+index) - c.copy_packet(r, Buffer.input, peer*size+index, sendtb=index) + c = chunk(r, "scratch", size * size + peer * size + index) + c.copy_packet(r, Buffer.input, peer * size + index, sendtb=index) Json() # Check() + parser = argparse.ArgumentParser() -parser.add_argument('num_gpus', type=int, help ='number of gpus') -parser.add_argument('instances', type=int, help='number of instances') +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") args = parser.parse_args() diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index ca4670f..26138c9 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -6,51 +6,51 @@ from msccl.topologies import * from msccl.language.collectives import AllReduce + def allreduce_allpairs(gpus, instances, protocol): size = gpus chunksperloop = gpus * gpus topology = fully_connected(size) collective = AllReduce(size, chunksperloop, True) - with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, - interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): - + with MSCCLPPProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, dependence_nop=True): for rank in range(size): for tb in range(size): index = rank * size - c = chunk(rank, Buffer.input, index+tb) + c = chunk(rank, Buffer.input, index + tb) # step1 make sure the data is ready for nghr in range(size): peer_index = nghr * size if rank != nghr: # signal peer the buffer is ready - c_peer = chunk(rank, Buffer.input, peer_index+tb) - c_peer.signal(nghr, Buffer.input, peer_index+tb, sendtb=tb) + c_peer = chunk(rank, Buffer.input, peer_index + tb) + c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) for nghr in range(size): if rank != nghr: - c.wait(nghr, Buffer.input, index+tb, recvtb=tb) + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) # step2 reduce the chunks and send to peers for nghr in range(size): if rank != nghr: - c.reduce_mscclpp(chunk(nghr, Buffer.input, index+tb), recvtb=tb) + c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) for nghr in range(size): if rank != nghr: - c.put(nghr, Buffer.input, index+tb, sendtb=tb) + c.put(nghr, Buffer.input, index + tb, sendtb=tb) # step3 signal the peers buffer is ready for nghr in range(size): if rank != nghr: - c.signal(nghr, Buffer.input, index+tb, sendtb=tb) + c.signal(nghr, Buffer.input, index + tb, sendtb=tb) for nghr in range(size): if rank != nghr: peer_index = nghr * size - c_peer = chunk(rank, Buffer.input, peer_index+tb) - c_peer.wait(nghr, Buffer.input, peer_index+tb, recvtb=tb) + c_peer = chunk(rank, Buffer.input, peer_index + tb) + c_peer.wait(nghr, Buffer.input, peer_index + tb, recvtb=tb) Json() + parser = argparse.ArgumentParser() -parser.add_argument('num_gpus', type=int, help ='number of gpus') -parser.add_argument('instances', type=int, help='number of instances') -parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple'], help='Protocol') +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") +parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") args = parser.parse_args() diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index f7f5779..7b89990 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -11,24 +11,25 @@ from msccl.language.buffer import * from msccl.language.rank_dag import * import msccl.collectives as collectives +import msccl.language.mscclpp as mscclpp +from msccl.language.mscclpp import * + # from msccl.language.visualize import * _current_program = None def _curr(): global _current_program - if _current_program == None: + if _current_program == None and mscclpp._current_program == None: raise RuntimeError("No Program in context") + if _current_program == None: + return mscclpp._current_program return _current_program -# For msccl++ program, we have one assumption that for channel can be identified by (send_buffer, recv_buffer, type, send_tb/recv_tb) -# which means the send_tb and recv_tb should be the same for a pair of signal and wait, also same for put/get operation. -# If one sender what to send data to peer want to use different tb in receiver side. We need to send to same tb in receiver side first, -# then performance a across tb sync. This is a limitation of current implementation. class MSCCLProgram: def __init__(self, name, topo, collective, instances, protocol='Simple', \ threadblock_policy=ThreadblockPolicy.auto, interleaved_replication=True, - instr_fusion=True, check_xml=True, dependence_nop=False, instance_policy=InstancePolicy.dup): + instr_fusion=True, check_xml=True, dependence_nop=False): self.name = name self.topo = topo self.collective = collective @@ -40,7 +41,6 @@ def __init__(self, name, topo, collective, instances, protocol='Simple', \ self.instr_fusion = instr_fusion self.check_xml = check_xml self.dependence_nop = dependence_nop - self.instance_policy = instance_policy assert protocol == 'Simple' or protocol == 'LL' or protocol == 'LL128', \ f'Given protocol: {protocol}. Must be either Simple, LL, LL128' self.run_opt = True # Runs optimization passes @@ -131,22 +131,9 @@ def lower(self): check_threadblock_ordering(self.instr_dag) return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) - # Lower program to MSCCLPP - def lower_mscclpp(self): - convert_to_exectuion_plan(self.instr_dag) - self.instr_dag.complete_channels() - if self.instr_fusion: - self.instr_dag.optimize_mscclpp(self.protocol) - self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2_mscclpp(self.instances, self.instance_policy) - return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) - def generate_xml(self): return ir_to_xml(self.lower(), dependence_nop=self.dependence_nop) - def generate_json(self): - return ir_to_json(self.lower_mscclpp(), dependence_nop=self.dependence_nop) - def print_chunk_dag(self): visualize_chunk_dag(self.chunk_dag.chunk_paths) @@ -157,9 +144,6 @@ def print_instr_dags(self, rank): else: visualize_instr_dag(self.instr_dags[rank].operations) -class MSCCLPPProgram: - pass - def Print(): _curr().print_chunk_dag() @@ -174,8 +158,6 @@ def create_scratch(rank, name): def XML(): print(_curr().generate_xml()) -def Json(): - print(_curr().generate_json()) def Check(): return _curr().check() @@ -215,117 +197,6 @@ def group(self, other): end = max(first._end(), second._end()) return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) - def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): - self.prog.check_buffer_exists(dst, buffer) - sender = self.rank - receiver = dst - assert sender != receiver, 'Cannot put to the same rank' - - # If index is not specified assume it is going to the same place in the next gpu - if index == -1 and buffer == None: - index = self.index - buffer = self.buffer - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - index = self.prog.buffers[dst][buffer].instance_size() - - # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) - buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) - - # Direct put - assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}' - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) - if self.prog.protocol == 'LL': - self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) - self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) - - def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): - self.prog.check_buffer_exists(dst, buffer) - sender = self.rank - receiver = dst - assert sender != receiver, 'Cannot put to the same rank' - - # If index is not specified assume it is going to the same place in the next gpu - if index == -1 and buffer == None: - index = self.index - buffer = self.buffer - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - index = self.prog.buffers[dst][buffer].instance_size() - - # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) - buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) - - # Direct put - assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}' - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - self.prog.instr_dag.add_put_packet(sender, self, dst_chunkref, sendtb, channel_type) - self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) - self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) - - def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): - self.prog.check_buffer_exists(src, buffer) - sender = src - receiver = self.rank - assert sender != receiver, 'Cannot get from the same rank' - - # If index is not specified assume it is going to the same place in the next gpu - if index == -1 and buffer == None: - index = self.index - buffer = self.buffer - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - index = self.prog.buffers[src][buffer].instance_size() - - # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) - buffer, index = self.prog.collective.get_buffer_index(src, buffer, index) - - # Direct get - assert (self.prog.topo.link(self.rank, src) or src == self.rank), f'No link from {self.rank} to {src}' - src_chunkref = self.prog.get_ref(src, buffer, index, self.size) - - self.prog.apply_send(src, buffer, index, self.rank, self.buffer, self.index, self.size) - self.prog.instr_dag.add_get(receiver, src_chunkref, self, recvtb, chan_type) - - # for signal and wait, currently we assuem the pair will use the same tb index. In future we need - # to infer the tb index from the instruction DAG Add a channel is define as (send_tb, src_buffer, recv_tb, dst_buffer, type). - # Then we can use DAG info to reduce the number of channels. - def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): - sender = self.rank - receiver = dst - assert sender != receiver, 'Cannot signal to the same rank' - - if index == -1 and buffer == None: - index = self.index - buffer = self.buffer - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - index = self.prog.buffers[dst][buffer].instance_size() - - # Direct signal - assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}' - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - - self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type) - - def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): - sender = src - receiver = self.rank - assert sender != receiver, 'Cannot wait on the same rank' - - if index == -1 and buffer == None: - index = self.index - buffer = self.buffer - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - index = self.prog.buffers[src][buffer].instance_size() - - # Direct signal - assert (self.prog.topo.link(self.rank, src) or src == self.rank), f'No link from {self.rank} to {src}' - src_chunkref = self.prog.get_ref(src, buffer, index, self.size) - - self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type) - # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): self.prog.check_buffer_exists(dst, buffer) @@ -365,35 +236,6 @@ def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): return dst_chunkref - def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): - self.prog.check_buffer_exists(dst, buffer) - - # If index is not specified assume it is going to the same place in the next gpu - if index == -1 and buffer == None: - index = self.index - buffer = self.buffer - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - index = self.prog.buffers[dst][buffer].instance_size() - - # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) - buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) - - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - - # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) - if dst_chunkref == self: - return - - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - - # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) - sender = self.rank - receiver = dst - assert sender == receiver, 'Packet copy only supports intra-rank communication' - self.prog.instr_dag.add_copy_packet(sender, self, dst_chunkref, sendtb) - - return dst_chunkref - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): # Receive reduce copy @@ -418,21 +260,6 @@ def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): return self - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce_mscclpp(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType.sm): - # Receive reduce copy - dst = self.rank - src = other_chunkref.rank - assert (self.prog.topo.link(src, dst) or src == dst), f'No link from {src} to {dst}' - self.prog.apply_reduce(src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size) - - if src != dst: - self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type) - else: - self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ChannelType.none) - - return self - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref def reduce_packet(self, other_chunkref, sendtb=-1): # Receive reduce copy diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py new file mode 100644 index 0000000..c329625 --- /dev/null +++ b/msccl/language/mscclpp.py @@ -0,0 +1,316 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from msccl.collectives import Collective +from msccl.language.buffer import * +from msccl.language.ir import * +from msccl.language.rank_dag import * +from msccl.language.tb_assignment import * +from msccl.topologies.topology import Topology + +_current_program = None + + +def _curr(): + global _current_program + if _current_program == None: + raise RuntimeError("No Program in context") + return _current_program + +# For msccl++ program, we have one assumption that for channel can be identified by (send_buffer, recv_buffer, type, send_tb/recv_tb) +# which means the send_tb and recv_tb should be the same for a pair of signal and wait, also same for put/get operation. +# If one sender what to send data to peer want to use different tb in receiver side. We need to send to same tb in receiver side first, +# then performance a across tb sync. This is a limitation of current implementation. +class MSCCLPPProgram: + def __init__( + self, + name: str, + topo: Topology, + collective: Collective, + instances: int, + protocol: str = "Simple", + instr_fusion: bool = True, + dependence_nop: bool = False, + instance_policy: InstancePolicy = InstancePolicy.dup, + ): + self.name = name + self.topo = topo + self.collective = collective + self.num_ranks = topo.num_nodes() + self.instances = instances + self.protocol = protocol + self.instr_fusion = instr_fusion + self.dependence_nop = dependence_nop + self.instance_policy = instance_policy + assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL" + self.run_opt = True # Runs optimization passes + # Initialize the input buffers + self.buffers = collective.init_buffers() + self.instr_dag = InstructionDAG(self.num_ranks, self.buffers) + for r in range(self.num_ranks): + for index, chunk in enumerate(self.buffers[r][Buffer.input]): + buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) + ref = self.get_ref(r, buffer, index, 1) + # self.chunk_dag.init_chunk(chunk, ref) + self.instr_dag.add_start(r, buffer, index, ref) + + def __enter__(self): + global _current_program + if _current_program != None: + raise RuntimeError("There is already a MSCCLPP Program in context") + _current_program = self + + def __exit__(self, exc_type, exc_value, exc_traceback): + global _current_program + if _current_program != self: + raise RuntimeError("This program is not currently in context") + _current_program = None + + # Tracks a send operation on the buffers + def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + db[dst_index + i] = sb[src_index + i] + + # Tracks a reduce operation on the buffers + def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + reduce_chunk = db[dst_index + i] + sent_chunk = sb[src_index + i] + db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) + + def get_ref(self, rank, buffer, index, size): + buffer, index = self.collective.get_buffer_index(rank, buffer, index) + return Ref(rank, buffer, index, size, self) + + def get_chunks(self, rank, buffer, index, size=1): + chunks = [None] * size + for i in range(0, size): + if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]): + chunks[i] = self.buffers[rank][buffer][index + i] + else: + chunks[i] = None + return chunks + + def check_buffer_exists(self, rank, name): + if name not in self.buffers[rank]: + self.buffers[rank][name] = BufferSlice(Buffer.scratch, name) + + # Checks that all chunks that should be on each rank + # are present in the output buffer. + def check(self): + return self.collective.check(self) + + # Lower program to MSCCLPP + def lower(self): + convert_to_exectuion_plan(self.instr_dag) + self.instr_dag.complete_channels() + if self.instr_fusion: + self.instr_dag.optimize_mscclpp(self.protocol) + self.instr_dag.lower_pt1(self.instances) + gpu_prgms = self.instr_dag.lower_pt2_mscclpp(self.instances, self.instance_policy) + return Program( + self.name, + self.collective.name, + self.collective.inplace, + self.protocol, + gpu_prgms, + ) + + def generate_json(self): + return ir_to_json(self.lower(), dependence_nop=self.dependence_nop) + + +def Json(): + print(_curr().generate_json()) + +@dataclass +class Ref(ChunkRef): + prog: MSCCLPPProgram + + def __repr__(self): + return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})" + + def _end(self): + return self.index + self.size + + def _get_chunk(self, index): + return self.prog.buffers[self.rank][self.buffer][index] + + def split(self, num): + assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts" + chunks = [None] * num + size = self.size // num + for i in range(num): + index = self.index + i * size + chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size) + return chunks + + def group(self, other): + assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}" + assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}" + if self.index < other.index: + first = self + second = other + else: + first = other + second = self + + end = max(first._end(), second._end()) + return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) + + def _get_buffer_index(self, remote_rank, buffer, index): + if index == -1 and buffer == None: + return self.buffer, self.index + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + return buffer, self.prog.buffers[remote_rank][buffer].instance_size() + return buffer, index + + def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): + self.prog.check_buffer_exists(dst, buffer) + sender = self.rank + receiver = dst + assert sender != receiver, "Cannot put to the same rank" + buffer, index = self._get_buffer_index(dst, buffer, index) + + # Direct put + assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) + + def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): + self.prog.check_buffer_exists(dst, buffer) + sender = self.rank + receiver = dst + assert sender != receiver, "Cannot put to the same rank" + buffer, index = self._get_buffer_index(dst, buffer, index) + + # Direct put + assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + self.prog.instr_dag.add_put_packet(sender, self, dst_chunkref, sendtb, channel_type) + self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) + self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) + + def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): + self.prog.check_buffer_exists(src, buffer) + sender = src + receiver = self.rank + assert sender != receiver, "Cannot get from the same rank" + buffer, index = self._get_buffer_index(src, buffer, index) + + # Direct get + assert self.prog.topo.link(self.rank, src) or src == self.rank, f"No link from {self.rank} to {src}" + src_chunkref = self.prog.get_ref(src, buffer, index, self.size) + + self.prog.apply_send(src, buffer, index, self.rank, self.buffer, self.index, self.size) + self.prog.instr_dag.add_get(receiver, src_chunkref, self, recvtb, chan_type) + + # for signal and wait, currently we assuem the pair will use the same tb index. In future we need + # to infer the tb index from the instruction DAG Add a channel is define as (send_tb, src_buffer, recv_tb, dst_buffer, type). + # Then we can use DAG info to reduce the number of channels. + def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): + sender = self.rank + receiver = dst + assert sender != receiver, "Cannot signal to the same rank" + buffer, index = self._get_buffer_index(dst, buffer, index) + + # Direct signal + assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type) + + def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): + sender = src + receiver = self.rank + assert sender != receiver, "Cannot wait on the same rank" + buffer, index = self._get_buffer_index(src, buffer, index) + + # Direct wait + assert self.prog.topo.link(self.rank, src) or src == self.rank, f"No link from {self.rank} to {src}" + src_chunkref = self.prog.get_ref(src, buffer, index, self.size) + self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type) + + # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) + def copy(self, dst, buffer=None, index=-1, sendtb=-1, ch=-1): + self.prog.check_buffer_exists(dst, buffer) + buffer, index = self._get_buffer_index(dst, buffer, index) + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) + if dst_chunkref == self: + return + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + + assert self.rank == dst, "Chunk copy only supports intra-rank communication" + self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, ch) + + return dst_chunkref + + def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): + self.prog.check_buffer_exists(dst, buffer) + buffer, index = self._get_buffer_index(dst, buffer, index) + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) + if dst_chunkref == self: + return + + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + assert self.rank == dst, "Packet copy only supports intra-rank communication" + self.prog.instr_dag.add_copy_packet(self.rank, self, dst_chunkref, sendtb) + + return dst_chunkref + + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType.sm): + dst = self.rank + src = other_chunkref.rank + assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}" + self.prog.apply_reduce( + src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size + ) + + if src != dst: + self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type) + else: + self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ChannelType.none) + + return self + + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce_packet(self, other_chunkref, sendtb=-1): + dst = self.rank + src = other_chunkref.rank + assert dst == src, "Packet reduce only supports intra-rank communication" + self.prog.apply_reduce( + src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size + ) + self.prog.instr_dag.add_reduce_packet(src, other_chunkref, self, sendtb) + return self + + def get_origin_index(self, index=0): + return self._get_chunk(index + self.index).origin_index + + def get_origin_rank(self, index=0): + return self._get_chunk(index + self.index).origin_rank + + def get_dst_index(self, index=0): + return self._get_chunk(index + self.index).dst_index + + def get_dst_rank(self, index=0): + return self._get_chunk(index + self.index).dst_rank + + def print_chunk_info(self, index=0): + print(self._get_chunk(index + self.index)) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3d74b6c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.black] +line-length = 120 +target-version = ['py38'] +include = '\.pyi?$' From 52fd030b4a0dd7368658f095179a3e88d7936c3b Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 08:40:22 +0000 Subject: [PATCH 30/76] update --- msccl/language/__init__.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 7b89990..da6cbf4 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -260,16 +260,6 @@ def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): return self - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce_packet(self, other_chunkref, sendtb=-1): - # Receive reduce copy - dst = self.rank - src = other_chunkref.rank - assert dst == src, 'Packet reduce only supports intra-rank communication' - self.prog.apply_reduce(src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size) - self.prog.instr_dag.add_reduce_packet(src, other_chunkref, self, sendtb) - return self - def get_origin_index(self, index=0): return self._get_chunk(index + self.index).origin_index From 7dd76b6949b5fe8f70265862b848aca6543f5afb Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 09:03:33 +0000 Subject: [PATCH 31/76] WIP --- .../allreduce_a100_allpairs_packet_mscclpp.py | 1 - .../allreduce_a100_allpairs_sm_mscclpp.py | 2 +- .../allreduce_a100_allpairs_sm_mscclpp_get.py | 21 +++++--- .../mscclang/allreduce_a100_ring_mscclpp.py | 34 +++++++------ examples/mscclang/put_mscclpp.py | 50 ------------------- msccl/language/ir.py | 9 ++-- msccl/language/mscclpp.py | 4 +- 7 files changed, 41 insertions(+), 80 deletions(-) delete mode 100644 examples/mscclang/put_mscclpp.py diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index 1e5ae90..abd0a54 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -18,7 +18,6 @@ def allreduce_allpairs(gpus, instances): collective, instances, protocol="LL", - dependence_nop=True, ): # Each rank sends the nth chunk to the nth rank into scratch space diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index 26138c9..74ae223 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -12,7 +12,7 @@ def allreduce_allpairs(gpus, instances, protocol): chunksperloop = gpus * gpus topology = fully_connected(size) collective = AllReduce(size, chunksperloop, True) - with MSCCLPPProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, dependence_nop=True): + with MSCCLPPProgram("allreduce_pairs", topology, collective, instances, protocol=protocol): for rank in range(size): for tb in range(size): index = rank * size diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py index aa7c39e..efca998 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -6,13 +6,19 @@ from msccl.topologies import * from msccl.language.collectives import AllReduce + def allreduce_allpairs(gpus, instances, protocol): size = gpus chunksperloop = gpus * gpus topology = fully_connected(size) collective = AllReduce(size, chunksperloop, True) - with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, - interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True, instance_policy=InstancePolicy.dup): + with MSCCLPPProgram( + "allreduce_pairs", + topology, + collective, + instances, + protocol=protocol, + ): # Each rank sends the nth chunk to the nth rank into scratch space for rank in range(size): @@ -23,7 +29,7 @@ def allreduce_allpairs(gpus, instances, protocol): for nghr in range(size): peer_index = nghr * size if rank != nghr: - c_peer = chunk(rank, Buffer.input, peer_index+tb) + c_peer = chunk(rank, Buffer.input, peer_index + tb) c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) for nghr in range(size): if rank != nghr: @@ -31,7 +37,7 @@ def allreduce_allpairs(gpus, instances, protocol): # reduce the chunks for nghr in range(size): if rank != nghr: - c.reduce_mscclpp(chunk(nghr, Buffer.input, index + tb), recvtb=tb) + c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) for nghr in range(size): if rank != nghr: c.signal(nghr, Buffer.input, index + tb, sendtb=tb) @@ -52,10 +58,11 @@ def allreduce_allpairs(gpus, instances, protocol): Json() + parser = argparse.ArgumentParser() -parser.add_argument('num_gpus', type=int, help ='number of gpus') -parser.add_argument('instances', type=int, help='number of instances') -parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple'], help='Protocol') +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") +parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") args = parser.parse_args() diff --git a/examples/mscclang/allreduce_a100_ring_mscclpp.py b/examples/mscclang/allreduce_a100_ring_mscclpp.py index adfd627..c4b8d9b 100644 --- a/examples/mscclang/allreduce_a100_ring_mscclpp.py +++ b/examples/mscclang/allreduce_a100_ring_mscclpp.py @@ -6,26 +6,32 @@ from msccl.topologies import * from msccl.language.collectives import AllReduce + # Ring all reduce for A100s -def allreduce_ring(size, instances, protocol): +def allreduce_ring(size, instances): topology = fully_connected(size) collective = AllReduce(size, size, True) - with MSCCLProgram(f"allreduce_ring", topology, collective, instances, - protocol=protocol, threadblock_policy=ThreadblockPolicy.manual): + with MSCCLPPProgram( + f"allreduce_ring", + topology, + collective, + instances, + protocol="Simple", + ): # Reduce ring - for step in range(0, size-1): + for step in range(0, size - 1): for index in range(0, size): rank = (index + step) % size next_rank = (index + step + 1) % size c = chunk(rank, Buffer.input, index) c.signal(next_rank, Buffer.input, index, 0) prev_rank = (index + step - 1) % size - c = chunk(rank, Buffer.input, (index+size-1)%size) - c.wait(prev_rank, Buffer.input, (index+size-1)%size, 0) - c.reduce_mscclpp(chunk(prev_rank, Buffer.input, (index+size-1)%size), recvtb=0) + c = chunk(rank, Buffer.input, (index + size - 1) % size) + c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0) + c.reduce(chunk(prev_rank, Buffer.input, (index + size - 1) % size), recvtb=0) # Propagate ring - for step in range(-1, size-2): + for step in range(-1, size - 2): for index in range(0, size): rank = (index + step) % size c = chunk(rank, Buffer.input, index) @@ -33,16 +39,16 @@ def allreduce_ring(size, instances, protocol): c.put(next_rank, Buffer.input, index, sendtb=0) c.signal(next_rank, Buffer.input, index, 0) prev_rank = (index + step - 1) % size - c = chunk(rank, Buffer.input, (index+size-1)%size) - c.wait(prev_rank, Buffer.input, (index+size-1)%size, 0) + c = chunk(rank, Buffer.input, (index + size - 1) % size) + c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0) Json() # Check() + parser = argparse.ArgumentParser() -parser.add_argument('num_gpus', type=int, help ='number of gpus') -parser.add_argument('instances', type=int, help='number of instances') -parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple', 'LL'], help ='protocol. Default: Simple') +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") args = parser.parse_args() -allreduce_ring(args.num_gpus, args.instances, args.protocol) +allreduce_ring(args.num_gpus, args.instances) diff --git a/examples/mscclang/put_mscclpp.py b/examples/mscclang/put_mscclpp.py deleted file mode 100644 index cdf2691..0000000 --- a/examples/mscclang/put_mscclpp.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -from msccl.language import * -from msccl.topologies import * -from msccl.language.collectives import AllReduce - -def allreduce_allpairs(gpus, instances, protocol): - size = gpus - chunksperloop = gpus - topology = fully_connected(size) - collective = AllReduce(size, chunksperloop, True) - with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, - interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): - - for rank in range(gpus): - c = chunk(rank, Buffer.input, rank, size=1) - for i in range(gpus - 1): - peer = (rank + i + 1) % gpus - c.put(peer, Buffer.input, rank, sendtb=0) - for i in range(gpus - 1): - peer = (rank + i + 1) % gpus - c.signal(peer, Buffer.input, rank, sendtb=0) - for i in range(gpus - 1): - peer = (rank + i + 1) % gpus - c.wait(peer, Buffer.input, peer, recvtb=0) - - # c = chunk(0, Buffer.input, 0, size=1) - # c.put(1, Buffer.input, index=0, sendtb=0) - # c.put(1, Buffer.input, index=1, sendtb=0) - # c.signal(1, Buffer.input, index=0, sendtb=0) - # c.signal(1, Buffer.input, index=1, sendtb=0) - - # dc0 = chunk(1, Buffer.input, 1, size=1) - # dc1 = chunk(1, Buffer.input, 0, size=1) - # dc0.wait(0, Buffer.input, index=0, recvtb=1) - # dc1.wait(0, Buffer.input, index=1, recvtb=1) - - Json() - #Check() - -parser = argparse.ArgumentParser() -parser.add_argument('num_gpus', type=int, help ='number of gpus') -parser.add_argument('instances', type=int, help='number of instances') -parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple', 'LL128', 'LL'], help='Protocol') - -args = parser.parse_args() - -allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 2a8a27e..b962bb3 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -120,13 +120,14 @@ class Instruction(Enum): read_reduce_copy = "rrc" read_reduce_copy_send = "rrcs" reduce_send = 'rs' - reduce_send_packet = 'rspkt' copy = 'cpy' - copy_packet = 'cpkt' reduce = 're' - reduce_packet = 'rpkt' delete = 'd' start = 'st' + # used by mscclpp only + copy_packet = 'cpkt' + reduce_send_packet = 'rspkt' + reduce_packet = 'rpkt' put = 'put' put_packet = 'ppkt' get = 'get' @@ -440,7 +441,7 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= ET.indent(algo_elem, space=' ') return ET.tostring(algo_elem, encoding='unicode') -def ir_to_json(program: Program, dependence_nop=False): +def ir_to_json(program: Program): # Figure out sizes of buffers based on usage buffer_sizes = defaultdict(lambda: 0) for gpu in program.gpus: diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py index c329625..9cd8f55 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp.py @@ -30,7 +30,6 @@ def __init__( instances: int, protocol: str = "Simple", instr_fusion: bool = True, - dependence_nop: bool = False, instance_policy: InstancePolicy = InstancePolicy.dup, ): self.name = name @@ -40,7 +39,6 @@ def __init__( self.instances = instances self.protocol = protocol self.instr_fusion = instr_fusion - self.dependence_nop = dependence_nop self.instance_policy = instance_policy assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL" self.run_opt = True # Runs optimization passes @@ -125,7 +123,7 @@ def lower(self): ) def generate_json(self): - return ir_to_json(self.lower(), dependence_nop=self.dependence_nop) + return ir_to_json(self.lower()) def Json(): From 451f31d7f86f9fe26bd591a6cbf63eae3f5092af Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 09:27:46 +0000 Subject: [PATCH 32/76] WIP --- .../mscclang/allreduce_a100_ring_mscclpp.py | 1 - msccl/language/ir.py | 234 -------------- msccl/language/ir_mscclpp.py | 286 ++++++++++++++++++ msccl/language/mscclpp.py | 2 +- 4 files changed, 287 insertions(+), 236 deletions(-) create mode 100644 msccl/language/ir_mscclpp.py diff --git a/examples/mscclang/allreduce_a100_ring_mscclpp.py b/examples/mscclang/allreduce_a100_ring_mscclpp.py index c4b8d9b..8801ac5 100644 --- a/examples/mscclang/allreduce_a100_ring_mscclpp.py +++ b/examples/mscclang/allreduce_a100_ring_mscclpp.py @@ -43,7 +43,6 @@ def allreduce_ring(size, instances): c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0) Json() - # Check() parser = argparse.ArgumentParser() diff --git a/msccl/language/ir.py b/msccl/language/ir.py index b962bb3..160ecd4 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -251,10 +251,6 @@ def __repr__(self): Instruction.recv_reduce_copy, Instruction.copy, Instruction.reduce, Instruction.recv_reduce_copy_send} -_local_src_insts_mscclpp = {Instruction.put, Instruction.signal, Instruction.copy, Instruction.reduce, Instruction.reduce_send} -_local_dst_insts_mscclpp = {Instruction.get, Instruction.wait, Instruction.read_reduce_copy, Instruction.copy, Instruction.reduce, Instruction.read_reduce_copy_send, Instruction.reduce_send} - - def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): # Figure out sizes of buffers based on usage buffer_sizes = defaultdict(lambda: 0) @@ -440,233 +436,3 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= if pretty_print: ET.indent(algo_elem, space=' ') return ET.tostring(algo_elem, encoding='unicode') - -def ir_to_json(program: Program): - # Figure out sizes of buffers based on usage - buffer_sizes = defaultdict(lambda: 0) - for gpu in program.gpus: - for tb in gpu.threadblocks: - for op in tb.ops: - if op.inst in _local_src_insts_mscclpp: - key = (gpu.rank, op.src.buffer) - buffer_sizes[key] = max( - buffer_sizes[key], op.src.index + op.src.size) - for src in op.srcs: - key = (gpu.rank, src.buffer) - buffer_sizes[key] = max( - buffer_sizes[key], src.index + src.size) - if op.inst in _local_dst_insts_mscclpp: - key = (gpu.rank, op.dst.buffer) - buffer_sizes[key] = max( - buffer_sizes[key], op.dst.index + op.dst.size) - # ignore remote buffers - if op.inst != Instruction.read_reduce_copy_send and op.inst != Instruction.reduce_send: - for dst in op.dsts: - key = (gpu.rank, dst.buffer) - buffer_sizes[key] = max( - buffer_sizes[key], dst.index + dst.size) - for gpu in program.gpus: - gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks) - gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) - gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) - - # get channel info for each GPU and threadblock - for gpu in program.gpus: - gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id) - chan_dict = {} - # the channel key is the tuple (srcBuffer, dstBuffer, type) - for tb in gpu.threadblocks: - for ch in tb.channels: - key = (ch.srcBuffer, ch.dstBuffer, ch.type) - if key not in chan_dict: - chan_dict[key] = [(tb.id, ch.connected_to)] - else: - chan_dict[key].append((tb.id, ch.connected_to)) - for key, value in chan_dict.items(): - chan_dict[key] = sorted(value) - gpu.channels = chan_dict - - # Remove the dependencies of wait after signal. They are actually depends on remote chunk - for gpu in program.gpus: - for tb in gpu.threadblocks: - for op in tb.ops: - if op.inst == Instruction.wait: - op.depends = list(filter(lambda dep: dep.inst != Instruction.signal, op.depends)) - - # Filter out redundant dependencies - # e.g. if op1 and op2 depend on op, and op1 happends before op2 - # then op2 does not need to explicitly depend on op - for gpu in program.gpus: - for tb in gpu.threadblocks: - running_depends = [] - for op in tb.ops: - op.depends = list( - filter(lambda dep: dep not in running_depends, op.depends)) - running_depends = running_depends + op.depends - - # Do some additional postprocessing of operations: - # - Expand operations with dependencies with no-ops - if program.protocol != "LL": # (TODO(binyli): fix it) ignore the dependence_nop for LL protocol - for gpu in program.gpus: - for tb in gpu.threadblocks: - new_ops = [] - for op in tb.ops: - # Expand extra dependencies into nop operations - for i, dep in enumerate(op.depends): - new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) - #op_tb_id[new_ops[-1]] = op_tb_id[op] - new_ops.append(op) - tb.ops = new_ops - - # update step and tid for ops - for gpu in program.gpus: - for tb in gpu.threadblocks: - for i, op in enumerate(tb.ops): - op.step = i - op.tb = tb.id - - # Need to calculate channel info for each GPU - nchannels = 0 - for gpu in program.gpus: - max_tb_channels = 0 - if len(gpu.threadblocks) > 0: - max_tb_channels = max(tb.channel+1 for tb in gpu.threadblocks) - nchannels = max(nchannels, max_tb_channels) - return dump_to_json(program) - -def dump_to_json(program: Program): - gpus = [] - - def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_type): - channel_ids = [] - for c in chunk_list: - key = (src_buffer, dst_buffer, chan_type) - channel_ids.extend([{"id": id, "off": c.index} for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) if ele == c.rank]) - return channel_ids - - def remove_empty_fields(d): - return {k: v for k, v in d.items() if v not in [None, "", [], {}]} - - for id, gpu in enumerate(program.gpus): - gpu_instance = { - 'id': id, - 'inputChunks': gpu.input_chunks, - 'outputChunks': gpu.output_chunks, - 'scratchChunks': gpu.scratch_chunks, - 'threadblocks': [], - "channels": [] - } - for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): - obj = { - "srcbuff": srcBuffer.value if hasattr(srcBuffer, 'value') else srcBuffer, - "dstbuff": dstBuffer.value if hasattr(dstBuffer, 'value') else dstBuffer, - "type": type.value, - "connectedTo": [eles[1] for eles in channels] - } - gpu_instance["channels"].append(obj) - gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) - for tb in gpu.threadblocks: - if tb.id < 0: - continue - ops = [] - tb_channels = [] - tb_channel_dict = {} - for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): - obj = { - "srcbuff": srcBuffer.value if hasattr(srcBuffer, 'value') else srcBuffer, - "dstbuff": dstBuffer.value if hasattr(dstBuffer, 'value') else dstBuffer, - "type": type.value, - "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id], - "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id], - } - tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj - tb_channels.append(obj) - tb_channels = filter(lambda x: x["type"] != "none", tb_channels) - for op in tb.ops: - o_buff = None - i_buff = None - dst_channel_ids = [] - src_channel_ids = [] - srcs = [] - src = None - dst = None - if op.tb == -1: - continue - if op.inst == Instruction.signal: - # get dst channel ids - dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - elif op.inst == Instruction.wait: - # get src channel ids - src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - elif op.inst == Instruction.read_reduce_copy: - src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - dst = op.dst - src = op.dst # TODO(binyli): fix this - elif op.inst == Instruction.read_reduce_copy_send: - src_channel_ids = get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type) - i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} - dst = op.dst - src = op.dst # TODO(binyli): fix this - elif op.inst == Instruction.reduce_send or op.inst == Instruction.reduce_send_packet: - dst_channel_ids = get_channel_ids(op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm) - o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} - srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) - dst = op.dst - src = op.dst # TODO(binyli): fix this - elif op.inst == Instruction.reduce: - srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) - dst = op.dst - elif op.inst == Instruction.nop: - instr = { - "name": op.inst.value, - "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)) - } - elif op.inst == Instruction.put or op.inst == Instruction.put_packet: - dst_channel_ids = get_channel_ids([op.dst], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - src = op.src - elif op.inst == Instruction.get: - src_channel_ids = get_channel_ids([op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) - i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - dst = op.dst - elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet: - src = op.src - dst = op.dst - if op.inst != Instruction.nop: - instr = { - "name": op.inst.value, - "i_buff": i_buff, - "i_cids": src_channel_ids, - "o_buff": o_buff, - "o_cids": dst_channel_ids, - "src": src.rank if src else None, - "srcs": srcs if srcs else None, - "srcbuff": src.buffer.value if src and src.buffer else None, - "srcoff": src.index if src else None, - "dst": dst.rank if dst else None, - "dstbuff": dst.buffer.value if dst and dst.buffer else None, - "dstoff": dst.index if dst else None, - "ctype": op.channel_type.value, - "cnt": op.cnt(), - } - ops.append(remove_empty_fields(instr)) - threadblock = { - 'id': tb.id, - 'ops': ops, - 'channels': list(map(lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cids": x["chanIds"]}, tb_channels)) - } - gpu_instance['threadblocks'].append(threadblock) - gpus.append(gpu_instance) - obj = { - 'name': program.name, - 'colletive': program.collective, - 'protocol': program.protocol, - 'inplace': program.inplace, - 'gpus': gpus - } - return json.dumps(obj, indent=2) diff --git a/msccl/language/ir_mscclpp.py b/msccl/language/ir_mscclpp.py new file mode 100644 index 0000000..dbc1aac --- /dev/null +++ b/msccl/language/ir_mscclpp.py @@ -0,0 +1,286 @@ +from collections import defaultdict +import json + +from msccl.language.ir import Buffer, ChannelType, Instruction, Program + +_local_src_insts_mscclpp = { + Instruction.put, + Instruction.put_packet, + Instruction.signal, + Instruction.copy, + Instruction.copy_packet, + Instruction.reduce, + Instruction.reduce_packet, + Instruction.reduce_send, + Instruction.reduce_send_packet, +} +_local_dst_insts_mscclpp = { + Instruction.get, + Instruction.wait, + Instruction.read_reduce_copy, + Instruction.copy, + Instruction.copy_packet, + Instruction.reduce, + Instruction.read_reduce_copy_send, + Instruction.reduce_send, + Instruction.reduce_packet, + Instruction.reduce_send_packet, +} + + +def ir_to_json(program: Program): + # Figure out sizes of buffers based on usage + buffer_sizes = defaultdict(lambda: 0) + for gpu in program.gpus: + for tb in gpu.threadblocks: + for op in tb.ops: + if op.inst in _local_src_insts_mscclpp: + key = (gpu.rank, op.src.buffer) + buffer_sizes[key] = max(buffer_sizes[key], op.src.index + op.src.size) + for src in op.srcs: + key = (gpu.rank, src.buffer) + buffer_sizes[key] = max(buffer_sizes[key], src.index + src.size) + if op.inst in _local_dst_insts_mscclpp: + key = (gpu.rank, op.dst.buffer) + buffer_sizes[key] = max(buffer_sizes[key], op.dst.index + op.dst.size) + # ignore remote buffers + if ( + op.inst != Instruction.read_reduce_copy_send + and op.inst != Instruction.reduce_send + and op.inst != Instruction.reduce_send_packet + ): + for dst in op.dsts: + key = (gpu.rank, dst.buffer) + buffer_sizes[key] = max(buffer_sizes[key], dst.index + dst.size) + for gpu in program.gpus: + gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks) + gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) + gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) + + # get channel info for each GPU and threadblock + for gpu in program.gpus: + gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id) + chan_dict = {} + # the channel key is the tuple (srcBuffer, dstBuffer, type) + for tb in gpu.threadblocks: + for ch in tb.channels: + key = (ch.srcBuffer, ch.dstBuffer, ch.type) + if key not in chan_dict: + chan_dict[key] = [(tb.id, ch.connected_to)] + else: + chan_dict[key].append((tb.id, ch.connected_to)) + for key, value in chan_dict.items(): + chan_dict[key] = sorted(value) + gpu.channels = chan_dict + + # Remove the dependencies of wait after signal. They are actually depends on remote chunk + for gpu in program.gpus: + for tb in gpu.threadblocks: + for op in tb.ops: + if op.inst == Instruction.wait: + op.depends = list(filter(lambda dep: dep.inst != Instruction.signal, op.depends)) + + # Filter out redundant dependencies + # e.g. if op1 and op2 depend on op, and op1 happends before op2 + # then op2 does not need to explicitly depend on op + for gpu in program.gpus: + for tb in gpu.threadblocks: + running_depends = [] + for op in tb.ops: + op.depends = list(filter(lambda dep: dep not in running_depends, op.depends)) + running_depends = running_depends + op.depends + + # Do some additional postprocessing of operations: + # - Expand operations with dependencies with no-ops + if program.protocol != "LL": # (TODO(binyli): fix it) ignore the dependence_nop for LL protocol + for gpu in program.gpus: + for tb in gpu.threadblocks: + new_ops = [] + for op in tb.ops: + # Expand extra dependencies into nop operations + for i, dep in enumerate(op.depends): + new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) + # op_tb_id[new_ops[-1]] = op_tb_id[op] + new_ops.append(op) + tb.ops = new_ops + + # update step and tid for ops + for gpu in program.gpus: + for tb in gpu.threadblocks: + for i, op in enumerate(tb.ops): + op.step = i + op.tb = tb.id + + # Need to calculate channel info for each GPU + nchannels = 0 + for gpu in program.gpus: + max_tb_channels = 0 + if len(gpu.threadblocks) > 0: + max_tb_channels = max(tb.channel + 1 for tb in gpu.threadblocks) + nchannels = max(nchannels, max_tb_channels) + return dump_to_json(program) + + +def dump_to_json(program: Program): + gpus = [] + + def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_type): + channel_ids = [] + for c in chunk_list: + key = (src_buffer, dst_buffer, chan_type) + channel_ids.extend( + [ + {"id": id, "off": c.index} + for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) + if ele == c.rank + ] + ) + return channel_ids + + def remove_empty_fields(d): + return {k: v for k, v in d.items() if v not in [None, "", [], {}]} + + for id, gpu in enumerate(program.gpus): + gpu_instance = { + "id": id, + "inputChunks": gpu.input_chunks, + "outputChunks": gpu.output_chunks, + "scratchChunks": gpu.scratch_chunks, + "threadblocks": [], + "channels": [], + } + for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): + obj = { + "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer, + "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer, + "type": type.value, + "connectedTo": [eles[1] for eles in channels], + } + gpu_instance["channels"].append(obj) + gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) + for tb in gpu.threadblocks: + if tb.id < 0: + continue + ops = [] + tb_channels = [] + tb_channel_dict = {} + for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): + obj = { + "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer, + "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer, + "type": type.value, + "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id], + "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id], + } + tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj + tb_channels.append(obj) + tb_channels = filter(lambda x: x["type"] != "none", tb_channels) + for op in tb.ops: + o_buff = None + i_buff = None + dst_channel_ids = [] + src_channel_ids = [] + srcs = [] + src = None + dst = None + if op.tb == -1: + continue + if op.inst == Instruction.signal: + # get dst channel ids + dst_channel_ids = get_channel_ids( + op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) + o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + elif op.inst == Instruction.wait: + # get src channel ids + src_channel_ids = get_channel_ids( + op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + elif op.inst == Instruction.read_reduce_copy: + src_channel_ids = get_channel_ids( + op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + dst = op.dst + src = op.dst # TODO(binyli): fix this + elif op.inst == Instruction.read_reduce_copy_send: + src_channel_ids = get_channel_ids( + op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) + dst_channel_ids = get_channel_ids( + op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type + ) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} + dst = op.dst + src = op.dst # TODO(binyli): fix this + elif op.inst == Instruction.reduce_send or op.inst == Instruction.reduce_send_packet: + dst_channel_ids = get_channel_ids( + op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm + ) + o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} + srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) + dst = op.dst + src = op.dst # TODO(binyli): fix this + elif op.inst == Instruction.reduce: + srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) + dst = op.dst + elif op.inst == Instruction.nop: + instr = { + "name": op.inst.value, + "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)), + } + elif op.inst == Instruction.put or op.inst == Instruction.put_packet: + dst_channel_ids = get_channel_ids( + [op.dst], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) + o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + src = op.src + elif op.inst == Instruction.get: + src_channel_ids = get_channel_ids( + [op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + ) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + dst = op.dst + elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet: + src = op.src + dst = op.dst + if op.inst != Instruction.nop: + instr = { + "name": op.inst.value, + "i_buff": i_buff, + "i_cids": src_channel_ids, + "o_buff": o_buff, + "o_cids": dst_channel_ids, + "src": src.rank if src else None, + "srcs": srcs if srcs else None, + "srcbuff": src.buffer.value if src and src.buffer else None, + "srcoff": src.index if src else None, + "dst": dst.rank if dst else None, + "dstbuff": dst.buffer.value if dst and dst.buffer else None, + "dstoff": dst.index if dst else None, + "ctype": op.channel_type.value, + "cnt": op.cnt(), + } + ops.append(remove_empty_fields(instr)) + threadblock = { + "id": tb.id, + "ops": ops, + "channels": list( + map( + lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cids": x["chanIds"]}, + tb_channels, + ) + ), + } + gpu_instance["threadblocks"].append(threadblock) + gpus.append(gpu_instance) + obj = { + "name": program.name, + "colletive": program.collective, + "protocol": program.protocol, + "inplace": program.inplace, + "gpus": gpus, + } + return json.dumps(obj, indent=2) diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py index 9cd8f55..7d8e430 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp.py @@ -3,7 +3,7 @@ from msccl.collectives import Collective from msccl.language.buffer import * -from msccl.language.ir import * +from msccl.language.ir_mscclpp import * from msccl.language.rank_dag import * from msccl.language.tb_assignment import * from msccl.topologies.topology import Topology From b2ceb138ca33d4909aa9483058bda2e869452a09 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 09:46:16 +0000 Subject: [PATCH 33/76] WIP --- msccl/language/buffer.py | 26 +++++++++++++++++++++---- msccl/language/channel.py | 25 ++++++++++++++++++++++++ msccl/language/ir.py | 37 +++--------------------------------- msccl/language/ir_mscclpp.py | 6 +++--- msccl/language/rank_dag.py | 6 +----- 5 files changed, 54 insertions(+), 46 deletions(-) create mode 100644 msccl/language/channel.py diff --git a/msccl/language/buffer.py b/msccl/language/buffer.py index c0ab297..cc2f01c 100755 --- a/msccl/language/buffer.py +++ b/msccl/language/buffer.py @@ -1,17 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from enum import Enum + + # Scratch buffer slice with manual indexing class BufferSlice: def __init__(self, buf, name): self.name = name self.buf = buf - self.offset = -1 # Offset into the global scratch buffer + self.offset = -1 # Offset into the global scratch buffer self.chunks = [] # Returns the global index into the scratch buffer def get_global_index(self, index): - assert (self.offset > -1), 'set_offset needs to be called first' + assert self.offset > -1, "set_offset needs to be called first" return self.offset + index def get_buffer(self): @@ -25,7 +28,7 @@ def set_offset(self, offset): def __getitem__(self, index): return self.chunks[index] - + def __setitem__(self, index, value): current_size = len(self.chunks) while index > current_size: @@ -37,4 +40,19 @@ def __setitem__(self, index, value): self.chunks[index] = value def __len__(self): - return len(self.chunks) \ No newline at end of file + return len(self.chunks) + + +class Buffer(Enum): + input = "i" + output = "o" + scratch = "s" + + def __str__(self): + return self.value + + def __lt__(self, other): + return self.value < other.value + + def __gt__(self, other): + return self.value < other.value diff --git a/msccl/language/channel.py b/msccl/language/channel.py new file mode 100644 index 0000000..fb97b7e --- /dev/null +++ b/msccl/language/channel.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from dataclasses import dataclass +from enum import Enum + +from msccl.language.buffer import Buffer + + +class ChannelType(Enum): + proxy = "proxy" + sm = "sm" + none = "none" + + def __str__(self): + return self.value + + +@dataclass(frozen=True) +class Channel: + srcBuffer: Buffer + dstBuffer: Buffer + type: ChannelType + connected_to: int diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 160ecd4..e4e2198 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -1,12 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import json from lxml import etree as ET from dataclasses import dataclass, field from enum import Enum from collections import defaultdict +from msccl.language.buffer import Buffer +from msccl.language.channel import ChannelType + @dataclass class Program: @@ -36,36 +38,6 @@ class Gpu: def scratch_size(self): return max((idx for addr, idx in self.scratch.items()), default=-1) + 1 -class ChannelType(Enum): - proxy = 'proxy' - sm = 'sm' - none = 'none' - - def __str__(self): - return self.value - -class Buffer(Enum): - input = 'i' - output = 'o' - scratch = 's' - - def __str__(self): - return self.value - - def __lt__(self, other): - return self.value < other.value - - def __gt__(self, other): - return self.value < other.value - -@dataclass(frozen=True) -class Channel: - srcBuffer: Buffer - dstBuffer: Buffer - type: ChannelType - connected_to: int - - @dataclass class Threadblock: id: int = -1 @@ -138,8 +110,6 @@ class Instruction(Enum): def __str__(self): return self.value - - @dataclass class ChunkRef: rank: int @@ -150,7 +120,6 @@ class ChunkRef: def __hash__(self): return hash((self.rank, self.buffer, self.index, self.size)) - @dataclass class Op: inst: Instruction diff --git a/msccl/language/ir_mscclpp.py b/msccl/language/ir_mscclpp.py index dbc1aac..7a57a51 100644 --- a/msccl/language/ir_mscclpp.py +++ b/msccl/language/ir_mscclpp.py @@ -1,7 +1,8 @@ from collections import defaultdict +from dataclasses import dataclass import json -from msccl.language.ir import Buffer, ChannelType, Instruction, Program +from msccl.language.ir import Buffer, ChannelType, Instruction, Op, Program _local_src_insts_mscclpp = { Instruction.put, @@ -92,7 +93,7 @@ def ir_to_json(program: Program): # Do some additional postprocessing of operations: # - Expand operations with dependencies with no-ops - if program.protocol != "LL": # (TODO(binyli): fix it) ignore the dependence_nop for LL protocol + if program.protocol != "LL": # TODO(binyli): fix this. Should based on OP type not algorithm for gpu in program.gpus: for tb in gpu.threadblocks: new_ops = [] @@ -100,7 +101,6 @@ def ir_to_json(program: Program): # Expand extra dependencies into nop operations for i, dep in enumerate(op.depends): new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) - # op_tb_id[new_ops[-1]] = op_tb_id[op] new_ops.append(op) tb.ops = new_ops diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 329fdc7..e9927fc 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -1,11 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from dataclasses import dataclass -from enum import Enum -import heapq -import functools - +from msccl.language.channel import Channel from msccl.language.ir import * from msccl.language.passes import * From b683d7f2a038cc671f6e3e6d6eef6f5ca06a49a2 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 10:59:42 +0000 Subject: [PATCH 34/76] update --- msccl/language/mscclpp.py | 45 +++++++++++++++----------------------- msccl/language/rank_dag.py | 19 ++++++++++------ 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py index 7d8e430..b6e3133 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp.py @@ -17,6 +17,7 @@ def _curr(): raise RuntimeError("No Program in context") return _current_program + # For msccl++ program, we have one assumption that for channel can be identified by (send_buffer, recv_buffer, type, send_tb/recv_tb) # which means the send_tb and recv_tb should be the same for a pair of signal and wait, also same for put/get operation. # If one sender what to send data to peer want to use different tb in receiver side. We need to send to same tb in receiver side first, @@ -129,6 +130,7 @@ def generate_json(self): def Json(): print(_curr().generate_json()) + @dataclass class Ref(ChunkRef): prog: MSCCLPPProgram @@ -240,8 +242,7 @@ def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): src_chunkref = self.prog.get_ref(src, buffer, index, self.size) self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type) - # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) - def copy(self, dst, buffer=None, index=-1, sendtb=-1, ch=-1): + def _copy(self, dst, buffer=None, index=-1, sendtb=-1, use_packet=False): self.prog.check_buffer_exists(dst, buffer) buffer, index = self._get_buffer_index(dst, buffer, index) @@ -252,51 +253,41 @@ def copy(self, dst, buffer=None, index=-1, sendtb=-1, ch=-1): self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) assert self.rank == dst, "Chunk copy only supports intra-rank communication" - self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, ch) + self.prog.instr_dag.add_copy_mscclpp(self.rank, self, dst_chunkref, sendtb, use_packet) return dst_chunkref - def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): - self.prog.check_buffer_exists(dst, buffer) - buffer, index = self._get_buffer_index(dst, buffer, index) - - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) - if dst_chunkref == self: - return - - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - assert self.rank == dst, "Packet copy only supports intra-rank communication" - self.prog.instr_dag.add_copy_packet(self.rank, self, dst_chunkref, sendtb) + # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) + def copy(self, dst, buffer=None, index=-1, sendtb=-1): + return self._copy(dst, buffer, index, sendtb) - return dst_chunkref + def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): + return self._copy(dst, buffer, index, sendtb, use_packet=True) - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType.sm): + def _reduce(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType.sm, use_packet=False): dst = self.rank src = other_chunkref.rank assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}" self.prog.apply_reduce( src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size ) + if use_packet: + assert src == dst, "Packet reduce only supports intra-rank communication" if src != dst: self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type) else: - self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ChannelType.none) + self.prog.instr_dag.add_reduce_mscclpp(src, other_chunkref, self, sendtb, use_packet) return self + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType.sm): + return self._reduce(other_chunkref, sendtb, recvtb, channel_type) + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref def reduce_packet(self, other_chunkref, sendtb=-1): - dst = self.rank - src = other_chunkref.rank - assert dst == src, "Packet reduce only supports intra-rank communication" - self.prog.apply_reduce( - src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size - ) - self.prog.instr_dag.add_reduce_packet(src, other_chunkref, self, sendtb) - return self + return self._reduce(other_chunkref, sendtb, use_packet=True) def get_origin_index(self, index=0): return self._get_chunk(index + self.index).origin_index diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index e9927fc..c576be5 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -117,8 +117,7 @@ def add_start(self, rank, buffer, index, ref): # InstructionDAG - adds a copy node def add_copy(self, rank, send_ref, recv_ref, tb, ch): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch, step=tb_step) + op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer @@ -130,9 +129,12 @@ def add_copy(self, rank, send_ref, recv_ref, tb, ch): self._write(rank, dstbuffer, dstindex, size, op) return op - def add_copy_packet(self, rank, send_ref, recv_ref, tb): + def add_copy_mscclpp(self, rank, send_ref, recv_ref, tb, use_packet = False): tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + if use_packet: + op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + else: + op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer @@ -161,10 +163,13 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, ch): self._write(rank, dstbuffer, dstindex, size, op, read=True) return op - # InstructionDAG - adds a redduce packet node - def add_reduce_packet(self, rank, send_ref, recv_ref, tb): + # InstructionDAG - adds a redduce node + def add_reduce_mscclpp(self, rank, send_ref, recv_ref, tb, use_packet = False): tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + if use_packet: + op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + else: + op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer From 3cf049ee6fe1160fa9cc37f62490d71f8c010653 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 11:11:27 +0000 Subject: [PATCH 35/76] WIP --- msccl/language/mscclpp.py | 32 +++++++++++----------------- msccl/language/rank_dag.py | 43 ++++++++++++++++++++++++-------------- 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py index b6e3133..599f068 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp.py @@ -173,35 +173,27 @@ def _get_buffer_index(self, remote_rank, buffer, index): return buffer, self.prog.buffers[remote_rank][buffer].instance_size() return buffer, index - def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): + def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False): self.prog.check_buffer_exists(dst, buffer) - sender = self.rank - receiver = dst - assert sender != receiver, "Cannot put to the same rank" + assert self.rank != dst, "Cannot put to the same rank" buffer, index = self._get_buffer_index(dst, buffer, index) # Direct put assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) - - def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): - self.prog.check_buffer_exists(dst, buffer) - sender = self.rank - receiver = dst - assert sender != receiver, "Cannot put to the same rank" - buffer, index = self._get_buffer_index(dst, buffer, index) + if use_packet: + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet) + self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none) + self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none) + else: + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type) - # Direct put - assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): + self._put(dst, buffer, index, sendtb, chan_type) - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - self.prog.instr_dag.add_put_packet(sender, self, dst_chunkref, sendtb, channel_type) - self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) - self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) + def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): + return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True) def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): self.prog.check_buffer_exists(src, buffer) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index c576be5..2c0134e 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -47,6 +47,7 @@ def same_buf_src(op1, op2): def same_chan_type(op1, op2): return op1.channel_type == op2.channel_type +# TODO:(binyli): Need to treat it as base class. For MSCCLPP/MSCCL implement different methods class InstructionDAG: def __init__(self, num_ranks, buffers): self.num_ranks = num_ranks @@ -148,15 +149,13 @@ def add_copy_mscclpp(self, rank, send_ref, recv_ref, tb, use_packet = False): # InstructionDAG - adds a redduce node def add_reduce(self, rank, send_ref, recv_ref, tb, ch): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch, step=tb_step) + op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer srcindex = send_ref.index size = recv_ref.size prev_ops = [] - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) # Sending part of reduce self._read(rank, srcbuffer, srcindex, size, op) # Reduce part of copy @@ -192,18 +191,32 @@ def add_send(self, rank, send_ref, recv_ref, tb, ch): return op # InstructionDAG - adds a put node - def add_put(self, rank, send_ref, recv_ref, tb, ch_type): + def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet = False): tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.put, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) - buffer = send_ref.buffer - index = send_ref.index - size = send_ref.size - self._read(rank, buffer, index, size, op) - return op - - def add_put_packet(self, rank, send_ref, recv_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.put_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) + if use_packet: + op = Op( + Instruction.put_packet, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + else: + op = Op( + Instruction.put, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) buffer = send_ref.buffer index = send_ref.index size = send_ref.size @@ -544,7 +557,6 @@ def dfs(op, cs): if op.inst == Instruction.start: dfs(op,-2) # Start instructions should start at -1 - # Given the set of operations that operate over a particular slot (rank, buffer, idx) fixed # Try and replace operations with pipelined ops like receive copy send (rcs) # or receive reduce send (rrs) and receive reduce copy send (rrcs) @@ -666,7 +678,6 @@ def lower_tbs(self): gpus.append(Gpu(rank, list(lowered_tbs.values()))) return gpus - # Automatically replicates the algorithm instance number of times # interleaved sets the replication policy # if True chunks are split as: ChunkA ChunkB -> ChunkA0 ChunkA1 .. ChunkB0 ChunkB1 ... From b1fd9520432eaa29fff49fd3f72a66493f0cbd6d Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 8 Apr 2024 11:14:07 +0000 Subject: [PATCH 36/76] Done for today --- msccl/language/mscclpp.py | 2 +- msccl/language/rank_dag.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py index 599f068..bea6502 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp.py @@ -112,7 +112,7 @@ def lower(self): convert_to_exectuion_plan(self.instr_dag) self.instr_dag.complete_channels() if self.instr_fusion: - self.instr_dag.optimize_mscclpp(self.protocol) + self.instr_dag.optimize_mscclpp() self.instr_dag.lower_pt1(self.instances) gpu_prgms = self.instr_dag.lower_pt2_mscclpp(self.instances, self.instance_policy) return Program( diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 2c0134e..07daada 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -528,7 +528,7 @@ def _parallel_signal_wait(self): continue queue = queue[1:] - def optimize_mscclpp(self, protocol): + def optimize_mscclpp(self): self._optimize_redandant_signal_wait() self._optimize_rrc_r_signal_wait() self._optimize_rrcs_rs() From a4728fa37054bfc2780f86155cff4fab8bea19b8 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 19 Apr 2024 09:03:29 +0000 Subject: [PATCH 37/76] update packet algo --- .../allreduce_a100_allpairs_packet_mscclpp.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index abd0a54..b052ec8 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -22,12 +22,13 @@ def allreduce_allpairs(gpus, instances): # Each rank sends the nth chunk to the nth rank into scratch space for r1 in range(size): - for r2 in range(size): - if r1 != r2: - for tb in range(size): - index = r2 * size + tb - c = chunk(r1, Buffer.input, index) - c.put_packet(r2, "scratch", index=r1 * size + tb, sendtb=tb) + for tb in range(size): + if tb == r1: + continue + remote_rank = tb + index = remote_rank * size + c = chunk(r1, Buffer.input, index, size) + c.put_packet(remote_rank, "scratch", index=r1*size, sendtb=tb) # Each rank performs a local reduction on the nth chunk # Utilize 8 threadblocks for this reduction for better parallelism @@ -43,11 +44,10 @@ def allreduce_allpairs(gpus, instances): # Each rank get final result from scratch space for r in range(size): - for index in range(size): - for peer in range(size): - if peer != r: - c = chunk(r, "scratch", size * size + peer * size + index) - c.copy_packet(r, Buffer.input, peer * size + index, sendtb=index) + for peer in range(size): + if peer != r: + c = chunk(r, "scratch", size * size + peer * size, size) + c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) Json() # Check() From 42c4a7db779d4888ab996544796026997f4d0085 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 19 Apr 2024 09:19:14 +0000 Subject: [PATCH 38/76] fix comments --- examples/mscclang/allreduce_a100_allpairs.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs.py b/examples/mscclang/allreduce_a100_allpairs.py index ea080bc..e6ec80d 100755 --- a/examples/mscclang/allreduce_a100_allpairs.py +++ b/examples/mscclang/allreduce_a100_allpairs.py @@ -11,9 +11,9 @@ def allreduce_allpairs(gpus, instances, protocol): chunksperloop = gpus * gpus topology = fully_connected(size) collective = AllReduce(size, chunksperloop, True) - with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, + with MSCCLProgram("allreduce_pairs", topology, collective, instances, protocol=protocol, interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True): - + # Each rank sends the nth chunk to the nth rank into scratch space for r1 in range(size): for r2 in range(size): @@ -28,7 +28,7 @@ def allreduce_allpairs(gpus, instances, protocol): for index in range(0, size * (size-1)): c = chunk(r, Buffer.input, r*size + (index % size)) c.reduce(chunk(r, 'scratch', index), sendtb=(index % size)) - + # Each rank sends the fully reduced nth chunk to all other gpus for r1 in range(size): for r2 in range(size): @@ -36,7 +36,7 @@ def allreduce_allpairs(gpus, instances, protocol): index = r1 * size c = chunk(r1, Buffer.input, index, size) c.copy(r2, Buffer.input, index, sendtb=r2, recvtb=r1) - + XML() Check() @@ -47,4 +47,4 @@ def allreduce_allpairs(gpus, instances, protocol): args = parser.parse_args() -allreduce_allpairs(args.num_gpus, args.instances, args.protocol) +allreduce_allpairs(args.num_gpus, args.instances, args.protocol) \ No newline at end of file From b494b753cb719ca164928471969fc74864cb8704 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 19 Apr 2024 15:06:42 +0000 Subject: [PATCH 39/76] OPT --- .../allreduce_a100_allpairs_sm_mscclpp_get.py | 6 ++-- msccl/language/ir_mscclpp.py | 6 ++-- msccl/language/rank_dag.py | 30 +++++++++++++++++-- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py index efca998..49f3606 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -35,7 +35,8 @@ def allreduce_allpairs(gpus, instances, protocol): if rank != nghr: c.wait(nghr, Buffer.input, index + tb, recvtb=tb) # reduce the chunks - for nghr in range(size): + for i in range(size): + nghr = (rank + i) % size if rank != nghr: c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) for nghr in range(size): @@ -50,7 +51,8 @@ def allreduce_allpairs(gpus, instances, protocol): index = nghr * size c = chunk(rank, Buffer.input, index + tb) c.wait(nghr, Buffer.input, index + tb, recvtb=tb) - for nghr in range(size): + for i in range(size): + nghr = (rank + i) % size index = nghr * size if rank != nghr: c = chunk(rank, Buffer.input, index + tb) diff --git a/msccl/language/ir_mscclpp.py b/msccl/language/ir_mscclpp.py index 7a57a51..1677399 100644 --- a/msccl/language/ir_mscclpp.py +++ b/msccl/language/ir_mscclpp.py @@ -181,6 +181,7 @@ def remove_empty_fields(d): dst_channel_ids = [] src_channel_ids = [] srcs = [] + dsts = [] src = None dst = None if op.tb == -1: @@ -239,10 +240,10 @@ def remove_empty_fields(d): src = op.src elif op.inst == Instruction.get: src_channel_ids = get_channel_ids( - [op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type ) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - dst = op.dst + dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts)) elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet: src = op.src dst = op.dst @@ -255,6 +256,7 @@ def remove_empty_fields(d): "o_cids": dst_channel_ids, "src": src.rank if src else None, "srcs": srcs if srcs else None, + "dsts": dsts if dsts else None, "srcbuff": src.buffer.value if src and src.buffer else None, "srcoff": src.index if src else None, "dst": dst.rank if dst else None, diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 07daada..b66475a 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -230,6 +230,8 @@ def add_get(self, rank, send_ref, recv_ref, tb, ch_type): index = recv_ref.index size = recv_ref.size self._write(rank, buffer, index, size, op) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) return op # InstructionDAG - adds a signal node. @@ -328,7 +330,7 @@ def complete_channels(self): chans.add(chan) tb.channels = list(chans) - def _optimize_redandant_signal_wait(self): + def _optimize_redundant_signal_wait(self): # For packet ops, we can remove signal/wait for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): @@ -489,6 +491,29 @@ def _optimize_rrcs_rs(self): continue queue = queue[1:] + # get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di]) + # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di]) + def _optimize_get_put(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.get: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if seq_op.inst == Instruction.get and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op): + op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) + op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + queue = queue[1:] + # For signal/wait ops, if they are independent of other operations and no other operations in between, # then merge them into a single signal/wait op # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) @@ -529,9 +554,10 @@ def _parallel_signal_wait(self): queue = queue[1:] def optimize_mscclpp(self): - self._optimize_redandant_signal_wait() + self._optimize_redundant_signal_wait() self._optimize_rrc_r_signal_wait() self._optimize_rrcs_rs() + self._optimize_get_put() self._parallel_signal_wait() From c2bd38fde0af301e926b482aa075c3a4315a0caa Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 22 Apr 2024 11:12:42 +0000 Subject: [PATCH 40/76] Fix --- msccl/language/ir_mscclpp.py | 6 ++++-- msccl/language/mscclpp.py | 12 ++++++------ msccl/language/rank_dag.py | 29 ++++++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/msccl/language/ir_mscclpp.py b/msccl/language/ir_mscclpp.py index 1677399..8cbcbb4 100644 --- a/msccl/language/ir_mscclpp.py +++ b/msccl/language/ir_mscclpp.py @@ -158,6 +158,7 @@ def remove_empty_fields(d): } gpu_instance["channels"].append(obj) gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) + gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"])) for tb in gpu.threadblocks: if tb.id < 0: continue @@ -175,6 +176,7 @@ def remove_empty_fields(d): tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj tb_channels.append(obj) tb_channels = filter(lambda x: x["type"] != "none", tb_channels) + tb_channels = sorted(tb_channels, key=lambda x: (x["srcbuff"], x["dstbuff"])) for op in tb.ops: o_buff = None i_buff = None @@ -234,10 +236,10 @@ def remove_empty_fields(d): } elif op.inst == Instruction.put or op.inst == Instruction.put_packet: dst_channel_ids = get_channel_ids( - [op.dst], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type ) o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - src = op.src + srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) elif op.inst == Instruction.get: src_channel_ids = get_channel_ids( op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py index bea6502..e9d5bfe 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp.py @@ -256,7 +256,7 @@ def copy(self, dst, buffer=None, index=-1, sendtb=-1): def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): return self._copy(dst, buffer, index, sendtb, use_packet=True) - def _reduce(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType.sm, use_packet=False): + def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False): dst = self.rank src = other_chunkref.rank assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}" @@ -269,17 +269,17 @@ def _reduce(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType if src != dst: self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type) else: - self.prog.instr_dag.add_reduce_mscclpp(src, other_chunkref, self, sendtb, use_packet) + self.prog.instr_dag.add_reduce_mscclpp(src, other_chunkref, self, recvtb, use_packet) return self # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, channel_type=ChannelType.sm): - return self._reduce(other_chunkref, sendtb, recvtb, channel_type) + def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm): + return self._reduce(other_chunkref, recvtb, channel_type) # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce_packet(self, other_chunkref, sendtb=-1): - return self._reduce(other_chunkref, sendtb, use_packet=True) + def reduce_packet(self, other_chunkref, recvtb=-1): + return self._reduce(other_chunkref, recvtb, use_packet=True) def get_origin_index(self, index=0): return self._get_chunk(index + self.index).origin_index diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index b66475a..119c85a 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -221,6 +221,8 @@ def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet = False): index = send_ref.index size = send_ref.size self._read(rank, buffer, index, size, op) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) return op def add_get(self, rank, send_ref, recv_ref, tb, ch_type): @@ -326,7 +328,7 @@ def complete_channels(self): chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) chans.add(chan) elif op.inst in recv_op: - chan = Channel(dst_buffer, src_buffer, op.channel_type, op.src.rank) + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) chans.add(chan) tb.channels = list(chans) @@ -464,6 +466,7 @@ def _optimize_rrcs_rs(self): continue if op.inst == Instruction.reduce: op.inst = Instruction.reduce_send + op.channel_type = ChannelType.sm op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) remove_op(next_op) tb.ops.remove(next_op) @@ -512,6 +515,30 @@ def _optimize_get_put(self): fused = True if fused: continue + elif op.inst == Instruction.put: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if seq_op.inst == Instruction.put and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op): + op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) + op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + elif op.inst == Instruction.put_packet: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if seq_op.inst == Instruction.put_packet and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op): + op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) + op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True queue = queue[1:] # For signal/wait ops, if they are independent of other operations and no other operations in between, From bb3aebe802ef9941bb3577ee3346bdb28ec5a3cd Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 24 Apr 2024 06:38:43 +0000 Subject: [PATCH 41/76] WIP --- msccl/language/collectives.py | 27 +++++++++++++-------------- msccl/language/ir.py | 1 + msccl/language/ir_mscclpp.py | 1 + msccl/language/mscclpp.py | 1 + 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index ddbacaa..dd243ec 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -3,11 +3,12 @@ from msccl.language import * class Collective(): - def __init__(self, num_ranks, chunk_factor, inplace): + def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups = 1): self.num_ranks = num_ranks self.chunk_factor = chunk_factor self.inplace = inplace self.name = "custom" + self.num_chunk_groups = num_chunk_groups def init_buffers(self): pass @@ -35,10 +36,10 @@ def init_buffers(self): chunk = Chunk(r, index, index//self.chunk_factor, index % self.chunk_factor + r*self.chunk_factor) input_buffer[index] = chunk if self.inplace: - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : input_buffer} else: - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : output_buffer} rank_buffers.append(buffers) return rank_buffers @@ -69,7 +70,7 @@ def __init__(self, num_ranks, chunk_factor, inplace): def init_buffers(self): rank_buffers = [] if self.inplace: - # Inplace AllGather only uses the output buffer + # Inplace AllGather only uses the output buffer for r in range(self.num_ranks): output_buffer = [None] * (self.num_ranks * self.chunk_factor) for ch in range(self.chunk_factor): @@ -83,11 +84,11 @@ def init_buffers(self): output_buffer = [None] * (self.num_ranks * self.chunk_factor) for ch in range(self.chunk_factor): input_buffer[ch] = Chunk(r, ch, -1, r*self.chunk_factor+ch) - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : output_buffer} rank_buffers.append(buffers) return rank_buffers - + # Expected output buffer for allgather def check(self, prog): correct = True @@ -106,7 +107,7 @@ def check(self, prog): correct = False return correct - + def get_buffer_index(self, rank, buffer, index): # For inplace AllGathers, the input buffer points into the output buffer if self.inplace and buffer == Buffer.input: @@ -115,12 +116,11 @@ def get_buffer_index(self, rank, buffer, index): return buffer, index - class AllReduce(Collective): def __init__(self, num_ranks, chunk_factor, inplace): - Collective.__init__(self, num_ranks, chunk_factor, inplace) - self.name = 'allreduce' + Collective.__init__(self, num_ranks, chunk_factor, inplace, num_ranks) + self.name = "allreduce" def init_buffers(self): chunks_per_node = self.chunk_factor @@ -133,10 +133,10 @@ def init_buffers(self): input_buffer.append(Chunk(r, c, -1, c)) # Input and output buffer are the same. if self.inplace: - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : input_buffer} else: - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : output_buffer} rank_buffers.append(buffers) return rank_buffers @@ -190,7 +190,7 @@ def init_buffers(self): for i in range(self.num_ranks): for c in range(self.chunk_factor): input_buffer.append(Chunk(r, i*self.chunk_factor + c, i, c)) - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : output_buffer} rank_buffers.append(buffers) return rank_buffers @@ -223,4 +223,3 @@ def get_buffer_index(self, rank, buffer, index): return Buffer.input, index + rank * self.chunk_factor else: return buffer, index - diff --git a/msccl/language/ir.py b/msccl/language/ir.py index e4e2198..13556d4 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -17,6 +17,7 @@ class Program: inplace: bool protocol: str gpus: list = field(default_factory=list) + num_chunk_groups: int = 1 @dataclass diff --git a/msccl/language/ir_mscclpp.py b/msccl/language/ir_mscclpp.py index 8cbcbb4..95cc9ec 100644 --- a/msccl/language/ir_mscclpp.py +++ b/msccl/language/ir_mscclpp.py @@ -146,6 +146,7 @@ def remove_empty_fields(d): "inputChunks": gpu.input_chunks, "outputChunks": gpu.output_chunks, "scratchChunks": gpu.scratch_chunks, + "chunkGroups": program.num_chunk_groups, "threadblocks": [], "channels": [], } diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py index e9d5bfe..1661925 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp.py @@ -121,6 +121,7 @@ def lower(self): self.collective.inplace, self.protocol, gpu_prgms, + self.collective.num_chunk_groups * self.instances ) def generate_json(self): From 40217f97853d12c0a10165149e5e1adee32ff950 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 24 Apr 2024 07:17:44 +0000 Subject: [PATCH 42/76] WIP --- msccl/language/__init__.py | 256 ++--------------------------------- msccl/language/msccl.py | 266 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 3 files changed, 278 insertions(+), 246 deletions(-) create mode 100644 msccl/language/msccl.py diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index da6cbf4..e38be01 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -1,280 +1,46 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from dataclasses import dataclass -from enum import Enum -import functools from msccl.language.ir import * from msccl.language.passes import * from msccl.language.tb_assignment import * from msccl.language.chunk import * from msccl.language.buffer import * from msccl.language.rank_dag import * -import msccl.collectives as collectives +import msccl.language.msccl as msccl_lang import msccl.language.mscclpp as mscclpp from msccl.language.mscclpp import * +from msccl.language.msccl import * +from typing import Union # from msccl.language.visualize import * -_current_program = None + def _curr(): - global _current_program - if _current_program == None and mscclpp._current_program == None: + if msccl_lang._current_program == None and mscclpp._current_program == None: raise RuntimeError("No Program in context") - if _current_program == None: + if msccl_lang._current_program == None: return mscclpp._current_program - return _current_program - - -class MSCCLProgram: - def __init__(self, name, topo, collective, instances, protocol='Simple', \ - threadblock_policy=ThreadblockPolicy.auto, interleaved_replication=True, - instr_fusion=True, check_xml=True, dependence_nop=False): - self.name = name - self.topo = topo - self.collective = collective - self.num_ranks = topo.num_nodes() - self.instances = instances - self.protocol = protocol - self.threadblock_policy = threadblock_policy - self.interleaved_replication = interleaved_replication - self.instr_fusion = instr_fusion - self.check_xml = check_xml - self.dependence_nop = dependence_nop - assert protocol == 'Simple' or protocol == 'LL' or protocol == 'LL128', \ - f'Given protocol: {protocol}. Must be either Simple, LL, LL128' - self.run_opt = True # Runs optimization passes - # Initialize the input buffers - # self.chunk_dag = ChunkDAG() - self.buffers = collective.init_buffers() - self.instr_dag = InstructionDAG(self.num_ranks, self.buffers) - for r in range(self.num_ranks): - for index, chunk in enumerate(self.buffers[r][Buffer.input]): - buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) - ref = self.get_ref(r, buffer, index, 1) - # self.chunk_dag.init_chunk(chunk, ref) - self.instr_dag.add_start(r, buffer, index, ref) - - def __enter__(self): - global _current_program - if _current_program != None: - raise RuntimeError("There is already a MSCCL Program in context") - _current_program = self - - def __exit__(self, exc_type, exc_value, exc_traceback): - global _current_program - if _current_program != self: - raise RuntimeError("This program is not currently in context") - _current_program = None - - # Tracks a send operation on the buffers - def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): - src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) - dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) - sb = self.buffers[src][src_buffer] - db = self.buffers[dst][dst_buffer] - for i in range(size): - db[dst_index + i] = sb[src_index + i] - - # Tracks a reduce operation on the buffers - def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): - src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) - dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) - sb = self.buffers[src][src_buffer] - db = self.buffers[dst][dst_buffer] - for i in range(size): - reduce_chunk = db[dst_index + i] - sent_chunk = sb[src_index + i] - db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) - - def get_ref(self, rank, buffer, index, size): - buffer, index = self.collective.get_buffer_index(rank, buffer, index) - return Ref(rank, buffer, index, size, self) - - def get_chunks(self, rank, buffer, index, size=1): - chunks = [None] * size - for i in range(0, size): - if self.buffers[rank][buffer] and index+i < len(self.buffers[rank][buffer]): - chunks[i] = self.buffers[rank][buffer][index+i] - else: - chunks[i] = None - return chunks - - def check_buffer_exists(self, rank, name): - if name not in self.buffers[rank]: - self.buffers[rank][name] = BufferSlice(Buffer.scratch, name) - - # Checks that all chunks that should be on each rank - # are present in the output buffer. - def check(self): - return self.collective.check(self) - - # Lower program to XML - def lower(self): - # self.chunk_dag._complete_metadata() - # self.chunk_dag.channel_assignment() - # self.chunk_dag.lower_instr_dag(self.instr_dag) - self.instr_dag.convert_set_list() # Pre-emptively convert sets to lists - if self.instr_fusion: - self.instr_dag.optimize() - self.instr_dag._complete_metadata() - if self.threadblock_policy == ThreadblockPolicy.manual: - manual_assign_tbs(self.instr_dag) - else: - auto_assign_tbs(self.instr_dag) - self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) - if self.check_xml: - # Check generated MSCCL-IR for correctness - no circular dependencies, sends and receives are ordered - # For very large programs, turn off check_xml when shipping - check_dependency_cycles(self.instr_dag.tbs) - check_threadblock_ordering(self.instr_dag) - return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) - - def generate_xml(self): - return ir_to_xml(self.lower(), dependence_nop=self.dependence_nop) - - def print_chunk_dag(self): - visualize_chunk_dag(self.chunk_dag.chunk_paths) - - def print_instr_dags(self, rank): - if rank == 0: - for r in range(len(self.ranks)): - visualize_instr_dag(self.instr_dags[r].operations) - else: - visualize_instr_dag(self.instr_dags[rank].operations) + return msccl_lang._current_program + def Print(): _curr().print_chunk_dag() -def chunk(rank, buffer, index, size=1): + +def chunk(rank, buffer, index, size=1) -> Union[mscclpp.Ref, msccl_lang.Ref]: if _curr().buffers[rank][buffer][index] is None: return None return _curr().get_ref(rank, buffer, index, size) + def create_scratch(rank, name): return _curr().create_scratch(rank, name) -def XML(): - print(_curr().generate_xml()) - def Check(): return _curr().check() -@dataclass -class Ref(ChunkRef): - prog: MSCCLProgram - - def __repr__(self): - return f'Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})' - - def _end(self): - return self.index + self.size - - def _get_chunk(self, index): - return self.prog.buffers[self.rank][self.buffer][index] - - def split(self, num): - assert (self.size % num == 0), f'Trying to split a chunk of {self.size} elements into {num} parts' - chunks = [None] * num - size = self.size // num - for i in range(num): - index = self.index + i * size - chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size) - return chunks - - def group(self, other): - assert (self.rank == other.rank), f'Trying to concatenate chunks on ranks {self.rank} and {other.rank}' - assert (self.buffer == other.buffer), f'Trying to concatenate chunks in {self.buffer} and {other.buffer}' - if self.index < other.index: - first = self - second = other - else: - first = other - second = self - - end = max(first._end(), second._end()) - return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) - - # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) - def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): - self.prog.check_buffer_exists(dst, buffer) - - # If index is not specified assume it is going to the same place in the next gpu - if index == -1 and buffer == None: - index = self.index - buffer = self.buffer - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - index = self.prog.buffers[dst][buffer].instance_size() - - # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) - buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) - - # Direct send - assert (self.prog.topo.link(self.rank, dst) or dst == self.rank), f'No link from {self.rank} to {dst}' - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - - # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) - if dst_chunkref == self: - return - - # chunks = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) - # overwritten_chunks = self.prog.get_chunks(dst, buffer, index, self.size) - - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - - # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) - sender = self.rank - receiver = dst - if sender != receiver: - sop = self.prog.instr_dag.add_send(sender, self, dst_chunkref, sendtb, ch) - rop = self.prog.instr_dag.add_recv(receiver, self, dst_chunkref, recvtb, ch, sop) - sop.recv_match = rop - else: - self.prog.instr_dag.add_copy(sender, self, dst_chunkref, sendtb, ch) - - return dst_chunkref - - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): - # Receive reduce copy - dst = self.rank - src = other_chunkref.rank - assert (self.prog.topo.link(src, dst) or src == dst), f'No link from {src} to {dst}' - # dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - - # chunks1 = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) - # chunks2 = self.prog.get_chunks(other_chunkref.rank, other_chunkref.buffer, other_chunkref.index self.size) - - self.prog.apply_reduce(src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size) - - # reduce_chunks = self.prog.get_chunks(dst, buffer, index, self.size) - # self.prog.chunk_dag.add_reduce(chunks1, chunks2, reduce_chunks, self, dst_chunkref, sendtb, recvtb, ch) - if src != dst: - sop = self.prog.instr_dag.add_send(src, other_chunkref, self, sendtb, ch) - rop = self.prog.instr_dag.add_recv_reduce_copy(dst, other_chunkref, self, recvtb, ch, sop) - sop.recv_match = rop - else: - self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ch) - - return self - - def get_origin_index(self, index=0): - return self._get_chunk(index + self.index).origin_index - - def get_origin_rank(self, index=0): - return self._get_chunk(index + self.index).origin_rank - - def get_dst_index(self, index=0): - return self._get_chunk(index + self.index).dst_index - - def get_dst_rank(self, index=0): - return self._get_chunk(index + self.index).dst_rank - - def print_chunk_info(self, index=0): - print(self._get_chunk(index + self.index)) - # @dataclass # class ChunkOp(): diff --git a/msccl/language/msccl.py b/msccl/language/msccl.py new file mode 100644 index 0000000..683bf8f --- /dev/null +++ b/msccl/language/msccl.py @@ -0,0 +1,266 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from msccl.language.buffer import * +from msccl.language.ir_mscclpp import * +from msccl.language.rank_dag import * +from msccl.language.tb_assignment import * + +_current_program = None + + +def _curr(): + global _current_program + if _current_program == None: + raise RuntimeError("No Program in context") + return _current_program + + +class MSCCLProgram: + def __init__( + self, + name, + topo, + collective, + instances, + protocol="Simple", + threadblock_policy=ThreadblockPolicy.auto, + interleaved_replication=True, + instr_fusion=True, + check_xml=True, + dependence_nop=False, + ): + self.name = name + self.topo = topo + self.collective = collective + self.num_ranks = topo.num_nodes() + self.instances = instances + self.protocol = protocol + self.threadblock_policy = threadblock_policy + self.interleaved_replication = interleaved_replication + self.instr_fusion = instr_fusion + self.check_xml = check_xml + self.dependence_nop = dependence_nop + assert ( + protocol == "Simple" or protocol == "LL" or protocol == "LL128" + ), f"Given protocol: {protocol}. Must be either Simple, LL, LL128" + self.run_opt = True # Runs optimization passes + # Initialize the input buffers + # self.chunk_dag = ChunkDAG() + self.buffers = collective.init_buffers() + self.instr_dag = InstructionDAG(self.num_ranks, self.buffers) + for r in range(self.num_ranks): + for index, chunk in enumerate(self.buffers[r][Buffer.input]): + buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) + ref = self.get_ref(r, buffer, index, 1) + # self.chunk_dag.init_chunk(chunk, ref) + self.instr_dag.add_start(r, buffer, index, ref) + + def __enter__(self): + global _current_program + if _current_program != None: + raise RuntimeError("There is already a MSCCL Program in context") + _current_program = self + + def __exit__(self, exc_type, exc_value, exc_traceback): + global _current_program + if _current_program != self: + raise RuntimeError("This program is not currently in context") + _current_program = None + + # Tracks a send operation on the buffers + def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + db[dst_index + i] = sb[src_index + i] + + # Tracks a reduce operation on the buffers + def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + reduce_chunk = db[dst_index + i] + sent_chunk = sb[src_index + i] + db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) + + def get_ref(self, rank, buffer, index, size): + buffer, index = self.collective.get_buffer_index(rank, buffer, index) + return Ref(rank, buffer, index, size, self) + + def get_chunks(self, rank, buffer, index, size=1): + chunks = [None] * size + for i in range(0, size): + if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]): + chunks[i] = self.buffers[rank][buffer][index + i] + else: + chunks[i] = None + return chunks + + def check_buffer_exists(self, rank, name): + if name not in self.buffers[rank]: + self.buffers[rank][name] = BufferSlice(Buffer.scratch, name) + + # Checks that all chunks that should be on each rank + # are present in the output buffer. + def check(self): + return self.collective.check(self) + + # Lower program to XML + def lower(self): + # self.chunk_dag._complete_metadata() + # self.chunk_dag.channel_assignment() + # self.chunk_dag.lower_instr_dag(self.instr_dag) + self.instr_dag.convert_set_list() # Pre-emptively convert sets to lists + if self.instr_fusion: + self.instr_dag.optimize() + self.instr_dag._complete_metadata() + if self.threadblock_policy == ThreadblockPolicy.manual: + manual_assign_tbs(self.instr_dag) + else: + auto_assign_tbs(self.instr_dag) + self.instr_dag.lower_pt1(self.instances) + gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) + if self.check_xml: + # Check generated MSCCL-IR for correctness - no circular dependencies, sends and receives are ordered + # For very large programs, turn off check_xml when shipping + check_dependency_cycles(self.instr_dag.tbs) + check_threadblock_ordering(self.instr_dag) + return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) + + def generate_xml(self): + return ir_to_xml(self.lower(), dependence_nop=self.dependence_nop) + + def print_chunk_dag(self): + visualize_chunk_dag(self.chunk_dag.chunk_paths) + + def print_instr_dags(self, rank): + if rank == 0: + for r in range(len(self.ranks)): + visualize_instr_dag(self.instr_dags[r].operations) + else: + visualize_instr_dag(self.instr_dags[rank].operations) + + +def XML(): + print(_curr().generate_xml()) + + +@dataclass +class Ref(ChunkRef): + prog: MSCCLProgram + + def __repr__(self): + return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})" + + def _end(self): + return self.index + self.size + + def _get_chunk(self, index): + return self.prog.buffers[self.rank][self.buffer][index] + + def split(self, num): + assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts" + chunks = [None] * num + size = self.size // num + for i in range(num): + index = self.index + i * size + chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size) + return chunks + + def group(self, other): + assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}" + assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}" + if self.index < other.index: + first = self + second = other + else: + first = other + second = self + + end = max(first._end(), second._end()) + return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) + + # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) + def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): + self.prog.check_buffer_exists(dst, buffer) + + # If index is not specified assume it is going to the same place in the next gpu + if index == -1 and buffer == None: + index = self.index + buffer = self.buffer + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + index = self.prog.buffers[dst][buffer].instance_size() + + # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) + buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) + + # Direct send + assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) + if dst_chunkref == self: + return + + # chunks = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) + # overwritten_chunks = self.prog.get_chunks(dst, buffer, index, self.size) + + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + + # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) + sender = self.rank + receiver = dst + if sender != receiver: + sop = self.prog.instr_dag.add_send(sender, self, dst_chunkref, sendtb, ch) + rop = self.prog.instr_dag.add_recv(receiver, self, dst_chunkref, recvtb, ch, sop) + sop.recv_match = rop + else: + self.prog.instr_dag.add_copy(sender, self, dst_chunkref, sendtb, ch) + + return dst_chunkref + + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): + # Receive reduce copy + dst = self.rank + src = other_chunkref.rank + assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}" + # dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + # chunks1 = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) + # chunks2 = self.prog.get_chunks(other_chunkref.rank, other_chunkref.buffer, other_chunkref.index self.size) + + self.prog.apply_reduce( + src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size + ) + + # reduce_chunks = self.prog.get_chunks(dst, buffer, index, self.size) + # self.prog.chunk_dag.add_reduce(chunks1, chunks2, reduce_chunks, self, dst_chunkref, sendtb, recvtb, ch) + if src != dst: + sop = self.prog.instr_dag.add_send(src, other_chunkref, self, sendtb, ch) + rop = self.prog.instr_dag.add_recv_reduce_copy(dst, other_chunkref, self, recvtb, ch, sop) + sop.recv_match = rop + else: + self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ch) + + return self + + def get_origin_index(self, index=0): + return self._get_chunk(index + self.index).origin_index + + def get_origin_rank(self, index=0): + return self._get_chunk(index + self.index).origin_rank + + def get_dst_index(self, index=0): + return self._get_chunk(index + self.index).dst_index + + def get_dst_rank(self, index=0): + return self._get_chunk(index + self.index).dst_rank + + def print_chunk_info(self, index=0): + print(self._get_chunk(index + self.index)) diff --git a/pyproject.toml b/pyproject.toml index 3d74b6c..d891952 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ [tool.black] -line-length = 120 +line-length = 140 target-version = ['py38'] include = '\.pyi?$' From 2e5bac6386698c025720f5efd1d1a1bc502ac606 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 24 Apr 2024 11:27:22 +0000 Subject: [PATCH 43/76] WIP --- .github/workflows/tests.yaml | 2 +- msccl/autosynth/ndv4_plans.py | 4 +- msccl/autosynth/registry.py | 8 +- msccl/language/__init__.py | 2 +- msccl/language/instruction_dag.py | 445 +++++++++ msccl/language/ir.py | 334 ++----- msccl/language/msccl.py | 7 +- .../{mscclpp.py => mscclpp/__init__.py} | 16 +- msccl/language/mscclpp/instruction_dag.py | 704 ++++++++++++++ .../language/{ir_mscclpp.py => mscclpp/ir.py} | 5 +- msccl/language/rank_dag.py | 872 ------------------ msccl/language/tb_assignment.py | 8 +- msccl/language/types.py | 240 +++++ msccl/language/visualize.py | 10 +- 14 files changed, 1489 insertions(+), 1168 deletions(-) create mode 100755 msccl/language/instruction_dag.py rename msccl/language/{mscclpp.py => mscclpp/__init__.py} (95%) create mode 100644 msccl/language/mscclpp/instruction_dag.py rename msccl/language/{ir_mscclpp.py => mscclpp/ir.py} (98%) delete mode 100755 msccl/language/rank_dag.py create mode 100644 msccl/language/types.py diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 97e9dea..3137151 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -11,7 +11,7 @@ jobs: strategy: matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.8, 3.9, 3.10] name: Test with Python ${{ matrix.python-version }} diff --git a/msccl/autosynth/ndv4_plans.py b/msccl/autosynth/ndv4_plans.py index a5453c0..84acb1e 100755 --- a/msccl/autosynth/ndv4_plans.py +++ b/msccl/autosynth/ndv4_plans.py @@ -7,7 +7,7 @@ from msccl.programs.alltoall_a100_yifan import alltoall_hierarchical from msccl.programs.alltoall_a100_8kp1 import alltoall_three_step from msccl.topologies import fully_connected -from msccl.language.ir import ThreadblockPolicy +from msccl.language.types import ThreadblockPolicy def register_ndv4_plans(): @@ -47,4 +47,4 @@ def ndv4_alltoall_three_step(prog, nodes): def ndv4_alltoall_hierarchical_config2(prog, nodes): alltoall_hierarchical(num_nodes=nodes, gpus_per_node=8) - + diff --git a/msccl/autosynth/registry.py b/msccl/autosynth/registry.py index 5dc89c6..fbe9810 100755 --- a/msccl/autosynth/registry.py +++ b/msccl/autosynth/registry.py @@ -9,7 +9,7 @@ import humanfriendly from msccl.language import MSCCLProgram, ir_to_xml -from msccl.language.ir import ThreadblockPolicy +from msccl.language.types import ThreadblockPolicy import msccl.language.collectives as lang_collectives from msccl.topologies import distributed_fully_connected @@ -62,7 +62,7 @@ def wrapped(machines): return decorator -def register_msccl_program(local_topology, collective, machine_type, machines=lambda x: True, sizes=None, protocol='Simple', +def register_msccl_program(local_topology, collective, machine_type, machines=lambda x: True, sizes=None, protocol='Simple', chunk_factor=1, priority=0, collective_obj=None, instances=1, inplace=False, threadblock_policy=ThreadblockPolicy.auto, interleaved_replication=True, dependence_nop=False): def decorator(fun): @@ -81,7 +81,7 @@ def wrapped(machines): co = lang_collectives.ReduceScatter(topology.num_nodes(), chunk_factor, inplace) else: raise RuntimeError(f'No collective_obj in msccl.language.collectives known for "{collective}"') - prog = MSCCLProgram(name, topology, co, instances, protocol, threadblock_policy=threadblock_policy, + prog = MSCCLProgram(name, topology, co, instances, protocol, threadblock_policy=threadblock_policy, interleaved_replication=interleaved_replication, dependence_nop=dependence_nop) with prog: fun(prog, machines) @@ -96,4 +96,4 @@ def wrapped(machines): machine_type, machines, sizes, protocol, priority) # Return the original function to not break other usage return fun - return decorator \ No newline at end of file + return decorator diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index e38be01..7f0822a 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -6,7 +6,7 @@ from msccl.language.tb_assignment import * from msccl.language.chunk import * from msccl.language.buffer import * -from msccl.language.rank_dag import * +from msccl.language.instruction_dag import * import msccl.language.msccl as msccl_lang import msccl.language.mscclpp as mscclpp from msccl.language.mscclpp import * diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py new file mode 100755 index 0000000..f3e9748 --- /dev/null +++ b/msccl/language/instruction_dag.py @@ -0,0 +1,445 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from abc import ABC, abstractmethod +from collections import defaultdict + +from msccl.language.buffer import Buffer +from msccl.language.types import ChunkRef, Gpu, InstancePolicy, Instruction, MscclInstruction, Op, Threadblock + + +def remove_op(op: Op): + for p in op.prev: + p.next.remove(op) + p.next += op.next + + for n in op.next: + n.prev.remove(op) + n.prev = op.prev.union(n.prev) + + +def merge_op(op: Op, other_op: Op): + for p in other_op.prev: + p.next.remove(other_op) + p.next.append(op) + + for n in other_op.next: + n.prev.remove(other_op) + n.prev.add(op) + + op.prev = op.prev.union(other_op.prev) + op.next += other_op.next + + +def same_tb(op1: Op, op2: Op): + return op1.tb == op2.tb and op1.channel == op2.channel + + +def same_count(op1: Op, op2: Op): + return op1.cnt() == op2.cnt() + + +def same_buf_dst(op1: Op, op2: Op): + return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index + + +def same_src_dst_buffer_type(op1: Op, op2: Op): + return op1.src.buffer == op2.src.buffer and op1.dst.buffer == op2.dst.buffer + + +def buf_dst_src_match(op1: Op, op2: Op): + return op1.dst.buffer == op2.src.buffer and op1.dst.index == op2.src.index + + +def same_buf_src(op1: Op, op2: Op): + return op1.src.buffer == op2.src.buffer and op1.src.index == op2.src.index + + +def same_chan_type(op1: Op, op2: Op): + return op1.channel_type == op2.channel_type + + +class InstructionDAG(ABC): + def __init__(self, num_ranks, buffers): + self.num_ranks = num_ranks + self.buffers = buffers + # State for the actual instruction DAG + self.operations = {} # slot -> operations + self.last_writer = {} # slot -> last writing op + self.last_readers = defaultdict(list) # slot -> list of last reading ops + # State for the MSCCL-IR + self.tbs = [] + for _ in range(num_ranks): + self.tbs.append({}) + self.tb_mapping = {} + self.num_channels = [1] * num_ranks + self.tb_steps = [{} for _ in range(num_ranks)] + + # InstructionDAG helper - identifies the dependencies for a write-type operation (recv, copy, rrc, reduce) + def _write(self, rank, buffer, index, size, op, read=False): + prev_ops = set() + for i in range(index, index + size): + slot = (rank, buffer, i) + if read: + assert slot in self.last_writer, f"Destination slot has never been written before a reduce {op}" + + # First write to this slot + if slot not in self.operations: + self.operations[slot] = op + + # If there are active readers - these are the previous operations + # Else the previous operation is the last write (if there is one) + readers = self.last_readers[slot] + if len(readers) > 0: + prev_ops.update(readers) + elif slot in self.last_writer: + prev_ops.add(self.last_writer[slot]) + + # Set the last_writer to this op, and clear all readers + self.last_writer[slot] = op + self.last_readers[slot] = [] + + # Update the next pointer of the previous ops + for prev_op in prev_ops: + prev_op.next.add(op) + op.prev.add(prev_op) + + # InstructionDAG helper - identifies the dependencies for read-type operations (send, copy, reduce) + def _read(self, rank, buffer, index, size, op): + prev_ops = set() + for i in range(index, index + size): + slot = (rank, buffer, i) + assert slot in self.last_writer, f"Slot has never been written before a read-type {op}" + # The previous operation for a reader is the last write to the slot + writer = self.last_writer[slot] + prev_ops.add(writer) + self.last_readers[slot].append(op) + + # Update the next pointer of the previous ops + for prev_op in prev_ops: + prev_op.next.add(op) + op.prev.add(prev_op) + + def _infer_dependencies(self): + for slot, ops in self.operations.items(): + frontier = [ops] + while len(frontier) > 0: + op = frontier[0] + # Dependencies for every op is the same as the ops that are stored in prev + # Filter out dependencies that are satisified by tbs executing ops sequentially + # If multiple dependent ops from the same tb keep the one that happens last + depends = {} + for dep_op in list(op.prev): + if dep_op.inst != Instruction.start: + tb = dep_op.tb + if tb not in depends or dep_op.step > depends[tb].step: + depends[tb] = dep_op + op.depends = list(depends.values()) + frontier = frontier[1:] + op.next + + # Convert local scratch buffers to index into one global scratch buffer + def _lower_chunk(self, chunk): + if chunk is not None and chunk.buffer is not Buffer.input and chunk.buffer is not Buffer.output: + buffer = self.buffers[chunk.rank][chunk.buffer].get_buffer() + index = self.buffers[chunk.rank][chunk.buffer].get_global_index(chunk.index) + return ChunkRef(chunk.rank, buffer, index, chunk.size) + return chunk + + # Assigns each scratch buffer an offset into the global scratch buffer + def _lower_buffers(self, instances): + for rank_buffers in self.buffers: + offset = 0 + for key, buf in rank_buffers.items(): + if key is not Buffer.input and key is not Buffer.output: + buf.set_offset(offset) + offset += buf.instance_size() * instances + + # Preprocess the threadblocks for lowering into xml + def _lower_tbs(self): + gpus = [] + for rank, rank_tbs in enumerate(self.instanced_tbs): + lowered_tbs = {} + for tbid, tb in rank_tbs.items(): + for op in tb.ops: + op.src = self._lower_chunk(op.src) + op.dst = self._lower_chunk(op.dst) + srcs = sorted(op.srcs, key=lambda x: x[1]) + dsts = sorted(op.dsts, key=lambda x: x[1]) + op.srcs = [self._lower_chunk(src[0]) for src in srcs] + op.dsts = [self._lower_chunk(dst[0]) for dst in dsts] + lowered_tbs[tbid] = tb + gpus.append(Gpu(rank, list(lowered_tbs.values()))) + return gpus + + # InstructionDAG - builds the roots of the DAG + def add_start(self, rank, buffer, index, ref): + slot = (rank, buffer, index) + op = Op(Instruction.start, rank, ref, ref, next=set(), prev=set(), chunk_step=-1) + self.operations[slot] = op + self.last_writer[slot] = op + + def convert_set_list(self): + ops = [] + visited = set() + for slot, op in self.operations.items(): + if op.inst == Instruction.start: + op.next = list(op.next) + for o in op.next: + ops.append(o) + elif op.inst != MscclInstruction.copy: + ops.append(op) + + while len(ops) > 0: + op = ops[0] + if op not in visited: + visited.add(op) + op.next = list(op.next) + ops = ops[1:] + op.next + else: + ops = ops[1:] + return visited + + def lower_pt1(self, instances: int): + self._infer_dependencies() + self._lower_buffers(instances) + + def lower_pt2(self, instances: int, instance_pollicy: InstancePolicy): + self.replicate(instances, instance_pollicy) + return self._lower_tbs() + + @abstractmethod + def optimize(self): + pass + + @abstractmethod + def replicate(self, instances: int, instance_policy: InstancePolicy): + pass + + +class MscclInstructionDAG(InstructionDAG): + + def __init__(self, num_ranks, buffers): + super().__init__(num_ranks, buffers) + + # InstructionDAG - adds a copy node + def add_copy(self, rank, send_ref, recv_ref, tb, ch): + op = Op(MscclInstruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + dstbuffer = recv_ref.buffer + dstindex = recv_ref.index + srcbuffer = send_ref.buffer + srcindex = send_ref.index + size = recv_ref.size + # Sending part of copy [Read] + self._read(rank, srcbuffer, srcindex, size, op) + # Receiving part of copy [Write] + self._write(rank, dstbuffer, dstindex, size, op) + return op + + # InstructionDAG - adds a redduce node + def add_reduce(self, rank, send_ref, recv_ref, tb, ch): + op = Op(MscclInstruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + dstbuffer = recv_ref.buffer + dstindex = recv_ref.index + srcbuffer = send_ref.buffer + srcindex = send_ref.index + size = recv_ref.size + # Sending part of reduce + self._read(rank, srcbuffer, srcindex, size, op) + # Reduce part of copy + self._write(rank, dstbuffer, dstindex, size, op, read=True) + return op + + # InstructionDAG - adds a send node + def add_send(self, rank, send_ref, recv_ref, tb, ch): + op = Op(MscclInstruction.send, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + return op + + # InstructionDAG - adds a recv node + def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op): + op = Op(MscclInstruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op) + op.send_match = send_op + return op + + # InstructionDAG - adds a rrc node + def add_recv_reduce_copy(self, rank, send_ref, recv_ref, tb, ch, send_op): + op = Op(MscclInstruction.recv_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op, read=True) + op.send_match = send_op + return op + + def optimize(self): + self._optimize_rrcs_rrs() + self._optimize_rcs() + + # Completes metadata for chunk_steps (number of steps from a start op) and priority (number of steps to the last op) + def _complete_metadata(self): + def dfs(op, cs): + op.chunk_step = max(op.chunk_step, cs + 1) + + if len(op.next) == 0 and op.recv_match is None: + op.priority = 0 + else: + for o in op.next: + dfs(o, op.chunk_step) + # Priority = +1 of the highest priority child + if len(op.next) > 0: + highest_next_priority = max([x.priority + 1 for x in op.next]) + op.priority = max(highest_next_priority, op.priority) + if op.is_send(): + dfs(op.recv_match, op.chunk_step) + op.priority = max(op.priority, op.recv_match.priority + 1) + + for chunk, op in self.operations.items(): + if op.inst == Instruction.start: + dfs(op, -2) # Start instructions should start at -1 + + # Given the set of operations that operate over a particular slot (rank, buffer, idx) fixed + # Try and replace operations with pipelined ops like receive copy send (rcs) + # or receive reduce send (rrs) and receive reduce copy send (rrcs) + # Rules: + # recv-copy-send + # recv(src, sbuf, si, _, _, _ ) send(_, _, _, dst, dbuf, di) -> recv_copy_send(src, sbuf, si, dst, dbuf, di) + def _optimize_rcs(self): + for slot, ops in self.operations.items(): + frontier = [ops] + while len(frontier) > 0: + op = frontier[0] + for next_op in op.next: + if ( + op.inst == MscclInstruction.recv + and next_op.inst == MscclInstruction.send + and same_tb(op, next_op) + and same_count(op, next_op) + and same_buf_dst(op, next_op) + ): + # recv -> rcs, remove send + op.inst = MscclInstruction.recv_copy_send + op.dst = next_op.dst + next_op.recv_match.send_match = op + op.recv_match = next_op.recv_match + remove_op(next_op) + break + frontier = frontier[1:] + op.next + + # recv-reduce-send - A rrc followed by a send that gets overwritten + # rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) recv(_, _, _, dst, dbuf, di) + # recv-reduce-copy-send - A rrc followed by a send that does not get overwritten + # rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) + def _optimize_rrcs_rrs(self): + # RRC/S -> RRS + for slot, ops in self.operations.items(): + frontier = [ops] + while len(frontier) > 0: + op = frontier[0] + if len(op.next) == 1: + next_op = op.next[0] + if len(next_op.next) == 1: + nnext_op = next_op.next[0] + if ( + op.inst == MscclInstruction.recv_reduce_copy + and next_op.inst == MscclInstruction.send + and nnext_op.inst is MscclInstruction.recv + and same_tb(op, next_op) + and same_count(op, next_op) + and same_buf_dst(op, next_op) + ): + op.inst = MscclInstruction.recv_reduce_send + op.dst = next_op.dst + next_op.recv_match.send_match = op + op.recv_match = next_op.recv_match + remove_op(next_op) + + if ( + op.inst == MscclInstruction.recv_reduce_copy + and next_op.inst == MscclInstruction.send + and same_tb(op, next_op) + and same_count(op, next_op) + and same_buf_dst(op, next_op) + ): + op.inst = MscclInstruction.recv_reduce_copy_send + op.dst = next_op.dst + next_op.recv_match.send_match = op + op.recv_match = next_op.recv_match + remove_op(next_op) + frontier = frontier[1:] + op.next + + # Automatically replicates the algorithm instance number of times + # interleaved sets the replication policy + # if True chunks are split as: ChunkA ChunkB -> ChunkA0 ChunkA1 .. ChunkB0 ChunkB1 ... + # if false chunks are divided as ChunkA0 ChunkB0 ChunkA1 ChunkB1 ... + # For collectives were chunks are designated for a particular GPU (e.g. AllToAll) + # only interleaved replication will be correct + # Interleaved policy only supports single count sends/receives from the input/output buffer + # (multicount ops are fine between scratch) + def replicate(self, instances, interleaved): + if instances == 1: + self.instanced_tbs = self.tbs + return + + self.instanced_tbs = [] + for _ in range(self.num_ranks): + self.instanced_tbs.append({}) + + def is_scratch(buffer): + return buffer != Buffer.input and buffer != Buffer.output + + def get_new_index(rank, buffer, index, size, i): + # Scratch buffers always use batched + if is_scratch(buffer): + buf_instance_len = self.buffers[rank][buffer].instance_size() + return buf_instance_len * i + index + # If this is operating on the input/output buffer then replication strategy can be either interleaved or batched + # This is to fit with the semantics of certain collectives + elif interleaved: + return index * instances + i * size + else: + return len(self.buffers[rank][buffer]) * i + index + + def get_instance_ref(ref): + iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i) + iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) + return iref + + max_channels = max(self.num_channels) + for i in range(instances): + # Generate all the threadblocks and ops + for rank, rank_tbs in enumerate(self.tbs): + # rank_channels = self.num_channels[rank] + for tbid, tb in rank_tbs.items(): + instance_channel = max_channels * i + tb.channel + itb = Threadblock(instance_channel, tb.send, tb.recv) + itbid = tbid * instances + i + itb.ops = [None] * len(tb.ops) + for s, op in enumerate(tb.ops): + isrc = get_instance_ref(op.src) + idst = get_instance_ref(op.dst) + idepends = [] + # Note: We don't need the fill out the rest of the metadata since replication is the last optimization + iop = Op(op.inst, op.rank, isrc, idst, idepends, op.step, itbid) + itb.ops[s] = iop + self.instanced_tbs[op.rank][itbid] = itb + + # Redo dependency analysis + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + for i in range(instances): + itbid = tbid * instances + i + itb = self.instanced_tbs[rank][itbid] + for op, iop in zip(tb.ops, itb.ops): + iop.depends = [None] * len(op.depends) + for s, dep in enumerate(op.depends): + dep_tbid = dep.tb + dep_itbid = dep_tbid * instances + i + dep_step = dep.step + iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 13556d4..3ae2ea3 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -2,224 +2,25 @@ # Licensed under the MIT License. from lxml import etree as ET -from dataclasses import dataclass, field -from enum import Enum from collections import defaultdict from msccl.language.buffer import Buffer -from msccl.language.channel import ChannelType - - -@dataclass -class Program: - name: str - collective: str - inplace: bool - protocol: str - gpus: list = field(default_factory=list) - num_chunk_groups: int = 1 - - -@dataclass -class Gpu: - rank: int - threadblocks: list = field(default_factory=list) - - # From ncclize - precopies: list = field(default_factory=list) - postcopies: list = field(default_factory=list) - inputs: dict = field(default_factory=dict) - outputs: dict = field(default_factory=dict) - input_chunks: int = 0 - output_chunks: int = 0 - scratch_chunks: int = 0 - scratch: dict = field(default_factory=dict) - channels: dict = field(default_factory=dict) - - def scratch_size(self): - return max((idx for addr, idx in self.scratch.items()), default=-1) + 1 - -@dataclass -class Threadblock: - id: int = -1 - channel: int = -1 - send: int = -1 - recv: int = -1 - ops: list = field(default_factory=list) - rbid: int = -1 # threadblock id of the receiver - channels: list = field(default_factory=list) - - def __eq__(self, other): - return self is other - - def __hash__(self): - return id(self) - - -class ChunkInstruction(Enum): - start = 'start' - reduce = 'reduce' - send = 'send' - - def __str__(self): - return self.value - - -class ThreadblockPolicy(Enum): - auto = 'auto' - manual = 'manual' - - def __str__(self): - return self.value - -class InstancePolicy(Enum): - # this means pack multi instrances to deal with the same chunk and share the channels - packed = 'packed' - # this means each instance deal with the same chunk - dup = 'dup' - - def __str__(self): - return self.value - - -class Instruction(Enum): - nop = 'nop' - send = 's' - recv = 'r' - recv_copy_send = 'rcs' - recv_reduce_send = 'rrs' - recv_reduce_copy = 'rrc' - recv_reduce_copy_send = 'rrcs' - read_reduce_copy = "rrc" - read_reduce_copy_send = "rrcs" - reduce_send = 'rs' - copy = 'cpy' - reduce = 're' - delete = 'd' - start = 'st' - # used by mscclpp only - copy_packet = 'cpkt' - reduce_send_packet = 'rspkt' - reduce_packet = 'rpkt' - put = 'put' - put_packet = 'ppkt' - get = 'get' - wait = 'wait' - signal = 'signal' - flush = 'flush' - - def __str__(self): - return self.value - -@dataclass -class ChunkRef: - rank: int - buffer: Buffer - index: int - size: int - - def __hash__(self): - return hash((self.rank, self.buffer, self.index, self.size)) - -@dataclass -class Op: - inst: Instruction - rank: int - src: ChunkRef - dst: ChunkRef - depends: list = field(default_factory=list) - step: int = -1 # Step in the TB - tb: int = -1 # TB this op is assigned to - prev: list = field(default_factory=list) # List of instructions that happen before - next: list = field(default_factory=list) # List of instructions that happen after - num: int = -1 - chunk_step: int = -1 - priority: int = -1 - recv_match = None - send_match = None - channel: int = -1 - channel_type: ChannelType = ChannelType.none - srcs: list = field(default_factory=list) - dsts: list = field(default_factory=list) - - def cnt(self): - if self.src: - if self.dst: - assert self.src.size == self.dst.size - return self.src.size - elif self.dst: - return self.dst.size - else: - return 0 - - def is_send(self): - return self.inst == Instruction.send or \ - self.inst == Instruction.recv_reduce_copy_send or \ - self.inst == Instruction.recv_copy_send or \ - self.inst == Instruction.recv_reduce_send - - def is_recv(self): - return self.inst == Instruction.recv or \ - self.inst == Instruction.recv_reduce_copy or \ - self.inst == Instruction.recv_reduce_copy_send or \ - self.inst == Instruction.recv_copy_send or \ - self.inst == Instruction.recv_reduce_send - - def is_fused(self): - return self.inst == Instruction.recv_reduce_copy_send or \ - self.inst == Instruction.recv_copy_send or \ - self.inst == Instruction.recv_reduce_send - - def is_local(self): - return self.inst == Instruction.copy or \ - self.inst == Instruction.reduce - - def peer(self): - if self.inst == Instruction.send: - return self.dst.rank - elif self.inst == Instruction.recv: - return self.src.rank - else: - return None - - def send_peer(self): - if self.is_send(): - return self.dst.rank - return -1 - - def recv_peer(self): - if self.is_recv(): - return self.src.rank - return -1 - - def __eq__(self, other): - return self is other - - def __lt__(self, other): - # Ordering of operations - # 1. Lower chunk step 2. Higher priority 3. Lower src index - if self.chunk_step == other.chunk_step: - if self.priority == other.priority: - return self.src.index < other.src.index - return self.priority > other.priority - return self.chunk_step < other.chunk_step - - def __gt__(self, other): - return not self < other - - def __hash__(self): - return id(self) - - def __repr__(self): - return f'Op({self.inst}, {self.rank}, {self.src}, {self.dst}, step:{self.step}, tb:{self.tb})' +from msccl.language.types import MscclInstruction as Instruction, Op, Program # Instructions where src is on local GPU _local_src_insts = {Instruction.send, Instruction.copy, Instruction.reduce} # Instructions where dst is on local GPU -_local_dst_insts = {Instruction.recv, Instruction.recv_copy_send, Instruction.recv_reduce_send, - Instruction.recv_reduce_copy, Instruction.copy, Instruction.reduce, - Instruction.recv_reduce_copy_send} +_local_dst_insts = { + Instruction.recv, + Instruction.recv_copy_send, + Instruction.recv_reduce_send, + Instruction.recv_reduce_copy, + Instruction.copy, + Instruction.reduce, + Instruction.recv_reduce_copy_send, +} + def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print=True, dependence_nop=False): # Figure out sizes of buffers based on usage @@ -229,19 +30,16 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= for op in tb.ops: if op.inst in _local_src_insts: key = (gpu.rank, op.src.buffer) - buffer_sizes[key] = max( - buffer_sizes[key], op.src.index + op.src.size) + buffer_sizes[key] = max(buffer_sizes[key], op.src.index + op.src.size) if op.inst in _local_dst_insts: key = (gpu.rank, op.dst.buffer) - buffer_sizes[key] = max( - buffer_sizes[key], op.dst.index + op.dst.size) + buffer_sizes[key] = max(buffer_sizes[key], op.dst.index + op.dst.size) tb_id = {} # Sort threadblocks in each GPU by peers and then the channel # This is important as in NCCL threadblocks using the same NVLink concurrently should be close together for gpu in program.gpus: - gpu.threadblocks = sorted(gpu.threadblocks, - key=lambda tb: (tb.send, tb.recv, tb.channel)) + gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: (tb.send, tb.recv, tb.channel)) for i, tb in enumerate(gpu.threadblocks): tb_id[tb] = i @@ -254,8 +52,7 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= for gpu in program.gpus: for tb in gpu.threadblocks: for op in tb.ops: - op.depends = list( - filter(lambda dep: op_tb_id[dep] != tb_id[tb], op.depends)) + op.depends = list(filter(lambda dep: op_tb_id[dep] != tb_id[tb], op.depends)) # Filter out redundant dependencies # e.g. if op1 and op2 depend on op, and op1 happends before op2 # then op2 does not need to explicitly depend on op @@ -263,8 +60,7 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= for tb in gpu.threadblocks: running_depends = [] for op in tb.ops: - op.depends = list( - filter(lambda dep: dep not in running_depends, op.depends)) + op.depends = list(filter(lambda dep: dep not in running_depends, op.depends)) running_depends = running_depends + op.depends # Mark all ops that have a dependence on them @@ -318,7 +114,7 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= for i, dep in enumerate(extra_deps): new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) op_idx[new_ops[-1]] = len(new_ops) - 1 - #op_tb_id[new_ops[-1]] = op_tb_id[op] + # op_tb_id[new_ops[-1]] = op_tb_id[op] new_ops.append(op) op_idx[new_ops[-1]] = len(new_ops) - 1 tb.ops = new_ops @@ -327,36 +123,42 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= for gpu in program.gpus: max_tb_channels = 0 if len(gpu.threadblocks) > 0: - max_tb_channels = max(tb.channel+1 for tb in gpu.threadblocks) + max_tb_channels = max(tb.channel + 1 for tb in gpu.threadblocks) nchannels = max(nchannels, max_tb_channels) # Generate the XML structure - algo_elem = ET.Element('algo') - algo_elem.set('name', program.name) - algo_elem.set('proto', program.protocol) - algo_elem.set('nchannels', str(nchannels)) + algo_elem = ET.Element("algo") + algo_elem.set("name", program.name) + algo_elem.set("proto", program.protocol) + algo_elem.set("nchannels", str(nchannels)) if old_format: - algo_elem.set('nchunksperloop', str( - max(max(buffer_sizes[(gpu.rank, Buffer.input)], buffer_sizes[(gpu.rank, Buffer.output)]) for gpu in program.gpus))) - algo_elem.set('ngpus', str(len(program.gpus))) - algo_elem.set('coll', program.collective) - algo_elem.set('inplace', str(1 if program.inplace else 0)) + algo_elem.set( + "nchunksperloop", + str( + max( + max(buffer_sizes[(gpu.rank, Buffer.input)], buffer_sizes[(gpu.rank, Buffer.output)]) + for gpu in program.gpus + ) + ), + ) + algo_elem.set("ngpus", str(len(program.gpus))) + algo_elem.set("coll", program.collective) + algo_elem.set("inplace", str(1 if program.inplace else 0)) for gpu in program.gpus: - gpu_elem = ET.SubElement(algo_elem, 'gpu') - gpu_elem.set('id', str(gpu.rank)) - gpu_elem.set('i_chunks', str(max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks))) - gpu_elem.set('o_chunks', str(max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks))) - gpu_elem.set('s_chunks', str(max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_size()))) + gpu_elem = ET.SubElement(algo_elem, "gpu") + gpu_elem.set("id", str(gpu.rank)) + gpu_elem.set("i_chunks", str(max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks))) + gpu_elem.set("o_chunks", str(max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks))) + gpu_elem.set("s_chunks", str(max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_size()))) for tb in gpu.threadblocks: - tb_elem = ET.SubElement(gpu_elem, 'tb') - tb_elem.set('id', str(tb_id[tb])) - tb_elem.set('send', str(tb.send)) - tb_elem.set('recv', str(tb.recv)) - tb_elem.set('chan', str(tb.channel)) + tb_elem = ET.SubElement(gpu_elem, "tb") + tb_elem.set("id", str(tb_id[tb])) + tb_elem.set("send", str(tb.send)) + tb_elem.set("recv", str(tb.recv)) + tb_elem.set("chan", str(tb.channel)) for op in tb.ops: - op_elem = ET.SubElement( - tb_elem, 'op' if not old_format else 'step') - op_elem.set('step' if not old_format else 's', str(op_idx[op])) - op_elem.set('type', str(op.inst)) + op_elem = ET.SubElement(tb_elem, "op" if not old_format else "step") + op_elem.set("step" if not old_format else "s", str(op_idx[op])) + op_elem.set("type", str(op.inst)) # The NCCL backend currently wants scratch at the end of output if not use_scratch: @@ -369,40 +171,40 @@ def ir_to_xml(program: Program, old_format=True, use_scratch=True, pretty_print= if old_format: if op.src is not None: - op_elem.set('srcbuf', str(op.src.buffer)) - op_elem.set('srcoff', str(op.src.index)) + op_elem.set("srcbuf", str(op.src.buffer)) + op_elem.set("srcoff", str(op.src.index)) else: - op_elem.set('srcbuf', 'i') - op_elem.set('srcoff', '-1') + op_elem.set("srcbuf", "i") + op_elem.set("srcoff", "-1") if op.dst is not None: - op_elem.set('dstbuf', str(op.dst.buffer)) - op_elem.set('dstoff', str(op.dst.index)) + op_elem.set("dstbuf", str(op.dst.buffer)) + op_elem.set("dstoff", str(op.dst.index)) else: - op_elem.set('dstbuf', 'o') - op_elem.set('dstoff', '-1') + op_elem.set("dstbuf", "o") + op_elem.set("dstoff", "-1") else: if op.is_send(): if op.src is not None: - op_elem.set('buf', str(op.src.buffer)) - op_elem.set('off', str(op.src.index)) + op_elem.set("buf", str(op.src.buffer)) + op_elem.set("off", str(op.src.index)) else: if op.dst is not None: - op_elem.set('buf', str(op.dst.buffer)) - op_elem.set('off', str(op.dst.index)) + op_elem.set("buf", str(op.dst.buffer)) + op_elem.set("off", str(op.dst.index)) if op.cnt() > 1 or old_format: - op_elem.set('cnt', str(op.cnt())) + op_elem.set("cnt", str(op.cnt())) assert len(op.depends) <= 1 if len(op.depends) == 1: - op_elem.set('depid', str(op_tb_id[op.depends[0]])) - op_elem.set('deps', str(op_idx[op.depends[0]])) + op_elem.set("depid", str(op_tb_id[op.depends[0]])) + op_elem.set("deps", str(op_idx[op.depends[0]])) elif old_format: - op_elem.set('depid', '-1') - op_elem.set('deps', '-1') + op_elem.set("depid", "-1") + op_elem.set("deps", "-1") if op in has_dependence: - op_elem.set('hasdep', '1') + op_elem.set("hasdep", "1") elif old_format: - op_elem.set('hasdep', '0') + op_elem.set("hasdep", "0") if pretty_print: - ET.indent(algo_elem, space=' ') - return ET.tostring(algo_elem, encoding='unicode') + ET.indent(algo_elem, space=" ") + return ET.tostring(algo_elem, encoding="unicode") diff --git a/msccl/language/msccl.py b/msccl/language/msccl.py index 683bf8f..e54756f 100644 --- a/msccl/language/msccl.py +++ b/msccl/language/msccl.py @@ -2,9 +2,10 @@ # Licensed under the MIT License. from msccl.language.buffer import * -from msccl.language.ir_mscclpp import * -from msccl.language.rank_dag import * +from msccl.language.instruction_dag import * +from msccl.language.passes import * from msccl.language.tb_assignment import * +from msccl.language.types import ThreadblockPolicy _current_program = None @@ -48,7 +49,7 @@ def __init__( # Initialize the input buffers # self.chunk_dag = ChunkDAG() self.buffers = collective.init_buffers() - self.instr_dag = InstructionDAG(self.num_ranks, self.buffers) + self.instr_dag = MscclInstructionDAG(self.num_ranks, self.buffers) for r in range(self.num_ranks): for index, chunk in enumerate(self.buffers[r][Buffer.input]): buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp/__init__.py similarity index 95% rename from msccl/language/mscclpp.py rename to msccl/language/mscclpp/__init__.py index 1661925..5e8e29f 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp/__init__.py @@ -3,8 +3,8 @@ from msccl.collectives import Collective from msccl.language.buffer import * -from msccl.language.ir_mscclpp import * -from msccl.language.rank_dag import * +from msccl.language.mscclpp.ir import * +from msccl.language.mscclpp.instruction_dag import MscclppInstructionDAG from msccl.language.tb_assignment import * from msccl.topologies.topology import Topology @@ -31,7 +31,7 @@ def __init__( instances: int, protocol: str = "Simple", instr_fusion: bool = True, - instance_policy: InstancePolicy = InstancePolicy.dup, + instance_policy: InstancePolicy = InstancePolicy.duplicated, ): self.name = name self.topo = topo @@ -45,7 +45,7 @@ def __init__( self.run_opt = True # Runs optimization passes # Initialize the input buffers self.buffers = collective.init_buffers() - self.instr_dag = InstructionDAG(self.num_ranks, self.buffers) + self.instr_dag = MscclppInstructionDAG(self.num_ranks, self.buffers) for r in range(self.num_ranks): for index, chunk in enumerate(self.buffers[r][Buffer.input]): buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) @@ -112,9 +112,9 @@ def lower(self): convert_to_exectuion_plan(self.instr_dag) self.instr_dag.complete_channels() if self.instr_fusion: - self.instr_dag.optimize_mscclpp() + self.instr_dag.optimize() self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2_mscclpp(self.instances, self.instance_policy) + gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.instance_policy) return Program( self.name, self.collective.name, @@ -246,7 +246,7 @@ def _copy(self, dst, buffer=None, index=-1, sendtb=-1, use_packet=False): self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) assert self.rank == dst, "Chunk copy only supports intra-rank communication" - self.prog.instr_dag.add_copy_mscclpp(self.rank, self, dst_chunkref, sendtb, use_packet) + self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, use_packet) return dst_chunkref @@ -270,7 +270,7 @@ def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_pa if src != dst: self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type) else: - self.prog.instr_dag.add_reduce_mscclpp(src, other_chunkref, self, recvtb, use_packet) + self.prog.instr_dag.add_reduce(src, other_chunkref, self, recvtb, use_packet) return self diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py new file mode 100644 index 0000000..093d41d --- /dev/null +++ b/msccl/language/mscclpp/instruction_dag.py @@ -0,0 +1,704 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from msccl.language.buffer import Buffer +from msccl.language.channel import Channel, ChannelType +from msccl.language.instruction_dag import ( + buf_dst_src_match, + merge_op, + remove_op, + same_buf_dst, + same_buf_src, + same_chan_type, + same_count, + same_src_dst_buffer_type, +) +from msccl.language.instruction_dag import InstructionDAG +from msccl.language.types import ChunkRef, InstancePolicy, MscclppInstruction as Instruction, Op, Threadblock + + +class MscclppInstructionDAG(InstructionDAG): + def __init__(self, num_ranks, buffers): + super().__init__(num_ranks, buffers) + + # InstructionDAG - adds a copy node + def add_copy(self, rank, send_ref, recv_ref, tb, use_packet=False): + tb_step = self._get_tb_step(rank, tb) + if use_packet: + op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + else: + op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + dstbuffer = recv_ref.buffer + dstindex = recv_ref.index + srcbuffer = send_ref.buffer + srcindex = send_ref.index + size = recv_ref.size + # Sending part of copy [Read] + self._read(rank, srcbuffer, srcindex, size, op) + # Receiving part of copy [Write] + self._write(rank, dstbuffer, dstindex, size, op) + return op + + # InstructionDAG - adds a redduce node + def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False): + tb_step = self._get_tb_step(rank, tb) + if use_packet: + op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + else: + op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + dstbuffer = recv_ref.buffer + dstindex = recv_ref.index + srcbuffer = send_ref.buffer + srcindex = send_ref.index + size = recv_ref.size + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + # Sending part of reduce + self._read(rank, srcbuffer, srcindex, size, op) + # Reduce part of copy + self._write(rank, dstbuffer, dstindex, size, op, read=True) + return op + + # InstructionDAG - adds a put node + def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False): + tb_step = self._get_tb_step(rank, tb) + if use_packet: + op = Op( + Instruction.put_packet, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + else: + op = Op( + Instruction.put, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + return op + + def add_get(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.get, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step + ) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + return op + + # InstructionDAG - adds a signal node. + def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.signal, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + # treat signal as a write since it can not be executed parallelly with read operations + self._write(rank, buffer, index, size, op) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + return op + + def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.wait, rank, src_ref, dst_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step + ) + buffer = dst_ref.buffer + index = dst_ref.index + size = dst_ref.size + self._write(rank, buffer, index, size, op) + op.srcs.append((ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size), tb_step)) + op.dsts.append((ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size), tb_step)) + return op + + def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.read_reduce_copy, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + self._write(rank, buffer, index, size, op, read=True) + return op + + def complete_channels(self): + send_op = [Instruction.put, Instruction.signal, Instruction.put_packet] + recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + chans = set() + for op in tb.ops: + src_buffer = ( + Buffer.scratch + if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output + else op.src.buffer + ) + dst_buffer = ( + Buffer.scratch + if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output + else op.dst.buffer + ) + if op.inst in send_op: + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) + chans.add(chan) + elif op.inst in recv_op: + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) + chans.add(chan) + tb.channels = list(chans) + + def _optimize_redundant_signal_wait(self): + # For packet ops, we can remove signal/wait + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.put_packet: + fused = False + for next_op in op.next: + if next_op.inst == Instruction.signal: + remove_op(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet: + fused = False + for prev_op in op.prev: + if prev_op.inst == Instruction.wait: + remove_op(prev_op) + fused = True + break + if fused: + continue + queue = queue[1:] + + # rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di) + # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di]) + # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) + # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di) + # reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di) + def _optimize_rrc_r_signal_wait(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.read_reduce_copy: + fused = False + for next_op in op.next: + if ( + next_op.inst == Instruction.read_reduce_copy + and same_count(op, next_op) + and same_buf_dst(op, next_op) + and same_chan_type(op, next_op) + ): + op.srcs.append( + ( + ChunkRef( + next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size + ), + next_op.step, + ) + ) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.reduce: + fused = False + for next_op in op.next: + if ( + next_op.inst == Instruction.reduce + and same_buf_dst(op, next_op) + and same_chan_type(op, next_op) + ): + op.srcs.append( + ( + ChunkRef( + next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size + ), + next_op.step, + ) + ) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.reduce_packet: + fused = False + for next_op in op.next: + if ( + next_op.inst == Instruction.reduce_packet + and same_buf_dst(op, next_op) + and same_chan_type(op, next_op) + ): + op.srcs.append( + ( + ChunkRef( + next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size + ), + next_op.step, + ) + ) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.signal: + fused = False + for next_op in op.next: + if ( + next_op.inst == Instruction.signal + and same_buf_src(op, next_op) + and same_chan_type(op, next_op) + ): + op.dsts.append( + ( + ChunkRef( + next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size + ), + next_op.step, + ) + ) + op.srcs.append( + ( + ChunkRef( + next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size + ), + next_op.step, + ) + ) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + elif op.inst == Instruction.wait: + fused = False + for next_op in op.next: + if ( + next_op.inst == Instruction.wait + and same_buf_dst(op, next_op) + and same_chan_type(op, next_op) + ): + op.srcs.append( + ( + ChunkRef( + next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size + ), + next_op.step, + ) + ) + op.dsts.append( + ( + ChunkRef( + next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size + ), + next_op.step, + ) + ) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + queue = queue[1:] + + # rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_) + # reduce(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rs(_,_,_,_,_,_) + def _optimize_rrcs_rs(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.read_reduce_copy or op.inst == Instruction.read_reduce_copy_send: + fused = False + for next_op in op.next: + if ( + next_op.inst == Instruction.put + and same_count(op, next_op) + and buf_dst_src_match(op, next_op) + and same_chan_type(op, next_op) + ): + if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: + continue + if op.inst == Instruction.read_reduce_copy: + op.inst = Instruction.read_reduce_copy_send + op.dsts.append( + ( + ChunkRef( + next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size + ), + next_op.step, + ) + ) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + if op.inst == Instruction.reduce or op.inst == Instruction.reduce_send: + fused = False + for next_op in op.next: + if ( + next_op.inst == Instruction.put + and same_count(op, next_op) + and buf_dst_src_match(op, next_op) + and next_op.channel_type == ChannelType.sm + ): + if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: + continue + if op.inst == Instruction.reduce: + op.inst = Instruction.reduce_send + op.channel_type = ChannelType.sm + op.dsts.append( + ( + ChunkRef( + next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size + ), + next_op.step, + ) + ) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + if op.inst == Instruction.reduce_packet or op.inst == Instruction.reduce_send_packet: + fused = False + for next_op in op.next: + if ( + next_op.inst == Instruction.put_packet + and same_count(op, next_op) + and buf_dst_src_match(op, next_op) + and next_op.channel_type == ChannelType.sm + ): + if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: + continue + if op.inst == Instruction.reduce_packet: + op.inst = Instruction.reduce_send_packet + op.channel_type = ChannelType.sm + op.dsts.append( + ( + ChunkRef( + next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size + ), + next_op.step, + ) + ) + remove_op(next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + fused = True + break + if fused: + continue + queue = queue[1:] + + # get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di]) + # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di]) + def _optimize_get_put(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.get: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if ( + seq_op.inst == Instruction.get + and same_src_dst_buffer_type(op, seq_op) + and same_chan_type(op, seq_op) + and same_count(op, seq_op) + ): + op.dsts.append( + ( + ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), + seq_op.step, + ) + ) + op.srcs.append( + ( + ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), + seq_op.step, + ) + ) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + elif op.inst == Instruction.put: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if ( + seq_op.inst == Instruction.put + and same_src_dst_buffer_type(op, seq_op) + and same_chan_type(op, seq_op) + and same_count(op, seq_op) + ): + op.dsts.append( + ( + ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), + seq_op.step, + ) + ) + op.srcs.append( + ( + ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), + seq_op.step, + ) + ) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + elif op.inst == Instruction.put_packet: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if ( + seq_op.inst == Instruction.put_packet + and same_src_dst_buffer_type(op, seq_op) + and same_chan_type(op, seq_op) + and same_count(op, seq_op) + ): + op.dsts.append( + ( + ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), + seq_op.step, + ) + ) + op.srcs.append( + ( + ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), + seq_op.step, + ) + ) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + queue = queue[1:] + + # For signal/wait ops, if they are independent of other operations and no other operations in between, + # then merge them into a single signal/wait op + # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) + def _parallel_signal_wait(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + if tbid == -1: + continue + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.signal: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if ( + seq_op.inst == Instruction.signal + and same_src_dst_buffer_type(op, seq_op) + and same_chan_type(op, seq_op) + ): + op.dsts.append( + ( + ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), + seq_op.step, + ) + ) + op.srcs.append( + ( + ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), + seq_op.step, + ) + ) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + elif op.inst == Instruction.wait: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if ( + seq_op.inst == Instruction.wait + and same_src_dst_buffer_type(op, seq_op) + and same_chan_type(op, seq_op) + ): + op.dsts.append( + ( + ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), + seq_op.step, + ) + ) + op.srcs.append( + ( + ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), + seq_op.step, + ) + ) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + queue = queue[1:] + + def _get_tb_step(self, rank: int, tb: int): + if tb in self.tb_steps[rank]: + self.tb_steps[rank][tb] += 1 + return self.tb_steps[rank][tb] + else: + self.tb_steps[rank][tb] = 0 + return 0 + + def optimize(self): + self._optimize_redundant_signal_wait() + self._optimize_rrc_r_signal_wait() + self._optimize_rrcs_rs() + self._optimize_get_put() + + self._parallel_signal_wait() + + def replicate(self, instances: int, instance_policy: InstancePolicy): + # update op step + for rank, rank_tbs in enumerate(self.tbs): + for _, tb in rank_tbs.items(): + for id, op in enumerate(tb.ops): + op.step = id + + if instances == 1: + self.instanced_tbs = self.tbs + return + + self.instanced_tbs = [] + for _ in range(self.num_ranks): + self.instanced_tbs.append({}) + + def is_scratch(buffer): + return buffer != Buffer.input and buffer != Buffer.output + + def get_new_index(rank, buffer, index, size, i): + # Scratch buffers always use batched + if is_scratch(buffer): + buf_instance_len = self.buffers[rank][buffer].instance_size() + return buf_instance_len * i + index + return len(self.buffers[rank][buffer]) * i + index + + def get_instance_ref(ref): + iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i) + iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) + return iref + + if instance_policy == InstancePolicy.duplicated: + for i in range(instances): + # Generate all the threadblocks and ops + for rank, rank_tbs in enumerate(self.tbs): + # rank_channels = self.num_channels[rank] + for tbid, tb in rank_tbs.items(): + itbid = tbid * instances + i + itb = Threadblock(id=itbid) + itb.ops = [None] * len(tb.ops) + for s, op in enumerate(tb.ops): + isrc = get_instance_ref(op.src) + idst = get_instance_ref(op.dst) + idepends = [] + # Note: We don't need the fill out the rest of the metadata since replication is the last optimization + iop = Op( + op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type + ) + itb.ops[s] = iop + for src, step in op.srcs: + isrc = get_instance_ref(src) + iop.srcs.append((isrc, step)) + for dst, step in op.dsts: + idst = get_instance_ref(dst) + iop.dsts.append((idst, step)) + for chan in tb.channels: + itb.channels.append(chan) + self.instanced_tbs[op.rank][itbid] = itb + + # Redo dependency analysis + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + for i in range(instances): + itbid = tbid * instances + i + itb = self.instanced_tbs[rank][itbid] + for op, iop in zip(tb.ops, itb.ops): + iop.depends = [None] * len(op.depends) + for s, dep in enumerate(op.depends): + dep_tbid = dep.tb + dep_itbid = dep_tbid * instances + i + dep_step = dep.step + iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] diff --git a/msccl/language/ir_mscclpp.py b/msccl/language/mscclpp/ir.py similarity index 98% rename from msccl/language/ir_mscclpp.py rename to msccl/language/mscclpp/ir.py index 95cc9ec..28df6f5 100644 --- a/msccl/language/ir_mscclpp.py +++ b/msccl/language/mscclpp/ir.py @@ -1,8 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from collections import defaultdict from dataclasses import dataclass import json -from msccl.language.ir import Buffer, ChannelType, Instruction, Op, Program +from msccl.language.types import Buffer, ChannelType, Op, Program, MscclppInstruction as Instruction _local_src_insts_mscclpp = { Instruction.put, diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py deleted file mode 100755 index 119c85a..0000000 --- a/msccl/language/rank_dag.py +++ /dev/null @@ -1,872 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from msccl.language.channel import Channel -from msccl.language.ir import * -from msccl.language.passes import * - -def remove_op(op): - for p in op.prev: - p.next.remove(op) - p.next += op.next - - for n in op.next: - n.prev.remove(op) - n.prev = op.prev.union(n.prev) - -def merge_op(op, other_op): - for p in other_op.prev: - p.next.remove(other_op) - p.next.append(op) - - for n in other_op.next: - n.prev.remove(other_op) - n.prev.add(op) - - op.prev = op.prev.union(other_op.prev) - op.next += other_op.next - -def same_tb(op1, op2): - return op1.tb == op2.tb and op1.channel == op2.channel - -def same_count(op1, op2): - return op1.cnt() == op2.cnt() - -def same_buf_dst(op1, op2): - return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index - -def same_src_dst_buffer_type(op1, op2): - return op1.src.buffer == op2.src.buffer and op1.dst.buffer == op2.dst.buffer - -def buf_dst_src_match(op1, op2): - return op1.dst.buffer == op2.src.buffer and op1.dst.index == op2.src.index - -def same_buf_src(op1, op2): - return op1.src.buffer == op2.src.buffer and op1.src.index == op2.src.index - -def same_chan_type(op1, op2): - return op1.channel_type == op2.channel_type - -# TODO:(binyli): Need to treat it as base class. For MSCCLPP/MSCCL implement different methods -class InstructionDAG: - def __init__(self, num_ranks, buffers): - self.num_ranks = num_ranks - self.buffers = buffers - # State for the actual instruction DAG - self.operations = {} # slot -> operations - self.last_writer = {} # slot -> last writing op - self.last_readers = defaultdict(list) # slot -> list of last reading ops - # State for the MSCCL-IR - self.tbs = [] - for _ in range(num_ranks): - self.tbs.append({}) - self.tb_mapping = {} - self.num_channels = [1] * num_ranks - self.tb_steps = [{} for _ in range(num_ranks)] - - # InstructionDAG helper - identifies the dependencies for a write-type operation (recv, copy, rrc, reduce) - def _write(self, rank, buffer, index, size, op, read=False): - prev_ops = set() - for i in range(index, index+size): - slot = (rank, buffer, i) - if read: - assert slot in self.last_writer, f"Destination slot has never been written before a reduce {op}" - - # First write to this slot - if slot not in self.operations: - self.operations[slot] = op - - # If there are active readers - these are the previous operations - # Else the previous operation is the last write (if there is one) - readers = self.last_readers[slot] - if len(readers) > 0: - prev_ops.update(readers) - elif slot in self.last_writer: - prev_ops.add(self.last_writer[slot]) - - # Set the last_writer to this op, and clear all readers - self.last_writer[slot] = op - self.last_readers[slot] = [] - - # Update the next pointer of the previous ops - for prev_op in prev_ops: - prev_op.next.add(op) - op.prev.add(prev_op) - - # InstructionDAG helper - identifies the dependencies for read-type operations (send, copy, reduce) - def _read(self, rank, buffer, index, size, op): - prev_ops = set() - for i in range(index, index+size): - slot = (rank, buffer, i) - assert slot in self.last_writer, f"Slot has never been written before a read-type {op}" - # The previous operation for a reader is the last write to the slot - writer = self.last_writer[slot] - prev_ops.add(writer) - self.last_readers[slot].append(op) - - # Update the next pointer of the previous ops - for prev_op in prev_ops: - prev_op.next.add(op) - op.prev.add(prev_op) - - # InstructionDAG - builds the roots of the DAG - def add_start(self, rank, buffer, index, ref): - slot = (rank, buffer, index) - op = Op(Instruction.start, rank, ref, ref, next=set(), prev=set(), chunk_step=-1) - self.operations[slot] = op - self.last_writer[slot] = op - - # InstructionDAG - adds a copy node - def add_copy(self, rank, send_ref, recv_ref, tb, ch): - op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) - dstbuffer = recv_ref.buffer - dstindex = recv_ref.index - srcbuffer = send_ref.buffer - srcindex = send_ref.index - size = recv_ref.size - # Sending part of copy [Read] - self._read(rank, srcbuffer, srcindex, size, op) - # Receiving part of copy [Write] - self._write(rank, dstbuffer, dstindex, size, op) - return op - - def add_copy_mscclpp(self, rank, send_ref, recv_ref, tb, use_packet = False): - tb_step = self._get_tb_step(rank, tb) - if use_packet: - op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) - else: - op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) - dstbuffer = recv_ref.buffer - dstindex = recv_ref.index - srcbuffer = send_ref.buffer - srcindex = send_ref.index - size = recv_ref.size - # Sending part of copy [Read] - self._read(rank, srcbuffer, srcindex, size, op) - # Receiving part of copy [Write] - self._write(rank, dstbuffer, dstindex, size, op) - return op - - # InstructionDAG - adds a redduce node - def add_reduce(self, rank, send_ref, recv_ref, tb, ch): - op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) - dstbuffer = recv_ref.buffer - dstindex = recv_ref.index - srcbuffer = send_ref.buffer - srcindex = send_ref.index - size = recv_ref.size - prev_ops = [] - # Sending part of reduce - self._read(rank, srcbuffer, srcindex, size, op) - # Reduce part of copy - self._write(rank, dstbuffer, dstindex, size, op, read=True) - return op - - # InstructionDAG - adds a redduce node - def add_reduce_mscclpp(self, rank, send_ref, recv_ref, tb, use_packet = False): - tb_step = self._get_tb_step(rank, tb) - if use_packet: - op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) - else: - op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) - dstbuffer = recv_ref.buffer - dstindex = recv_ref.index - srcbuffer = send_ref.buffer - srcindex = send_ref.index - size = recv_ref.size - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - # Sending part of reduce - self._read(rank, srcbuffer, srcindex, size, op) - # Reduce part of copy - self._write(rank, dstbuffer, dstindex, size, op, read=True) - return op - - # InstructionDAG - adds a send node - def add_send(self, rank, send_ref, recv_ref, tb, ch): - op = Op(Instruction.send, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) - buffer = send_ref.buffer - index = send_ref.index - size = send_ref.size - self._read(rank, buffer, index, size, op) - return op - - # InstructionDAG - adds a put node - def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet = False): - tb_step = self._get_tb_step(rank, tb) - if use_packet: - op = Op( - Instruction.put_packet, - rank, - send_ref, - recv_ref, - next=set(), - prev=set(), - tb=tb, - channel_type=ch_type, - step=tb_step, - ) - else: - op = Op( - Instruction.put, - rank, - send_ref, - recv_ref, - next=set(), - prev=set(), - tb=tb, - channel_type=ch_type, - step=tb_step, - ) - buffer = send_ref.buffer - index = send_ref.index - size = send_ref.size - self._read(rank, buffer, index, size, op) - op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - return op - - def add_get(self, rank, send_ref, recv_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.get, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) - buffer = recv_ref.buffer - index = recv_ref.index - size = recv_ref.size - self._write(rank, buffer, index, size, op) - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) - return op - - # InstructionDAG - adds a signal node. - def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.signal, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) - buffer = send_ref.buffer - index = send_ref.index - size = send_ref.size - # treat signal as a write since it can not be executed parallelly with read operations - self._write(rank, buffer, index, size, op) - op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - return op - - def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.wait, rank, src_ref, dst_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) - buffer = dst_ref.buffer - index = dst_ref.index - size = dst_ref.size - self._write(rank, buffer, index, size, op) - op.srcs.append((ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size), tb_step)) - op.dsts.append((ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size), tb_step)) - return op - - # InstructionDAG - adds a recv node - def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op): - op = Op(Instruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) - buffer = recv_ref.buffer - index = recv_ref.index - size = recv_ref.size - self._write(rank, buffer, index, size, op) - op.send_match = send_op - return op - - # InstructionDAG - adds a rrc node - def add_recv_reduce_copy(self, rank, send_ref, recv_ref, tb, ch, send_op): - op = Op(Instruction.recv_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) - buffer = recv_ref.buffer - index = recv_ref.index - size = recv_ref.size - self._write(rank, buffer, index, size, op, read=True) - op.send_match = send_op - return op - - def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.read_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) - buffer = recv_ref.buffer - index = recv_ref.index - size = recv_ref.size - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - self._write(rank, buffer, index, size, op, read=True) - return op - - def convert_set_list(self): - ops = [] - visited = set() - for slot, op in self.operations.items(): - if op.inst == Instruction.start: - op.next = list(op.next) - for o in op.next: - ops.append(o) - elif op.inst != Instruction.copy: - ops.append(op) - - while len(ops) > 0: - op = ops[0] - if op not in visited: - visited.add(op) - op.next = list(op.next) - ops = ops[1:] + op.next - else: - ops = ops[1:] - return visited - - def optimize(self): - self._optimize_rrcs_rrs() - self._optimize_rcs() - - def complete_channels(self): - send_op = [Instruction.put, Instruction.signal, Instruction.put_packet] - recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - chans = set() - for op in tb.ops: - src_buffer = Buffer.scratch if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output else op.src.buffer - dst_buffer = Buffer.scratch if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output else op.dst.buffer - if op.inst in send_op: - chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) - chans.add(chan) - elif op.inst in recv_op: - chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) - chans.add(chan) - tb.channels = list(chans) - - def _optimize_redundant_signal_wait(self): - # For packet ops, we can remove signal/wait - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.put_packet: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.signal: - remove_op(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet: - fused = False - for prev_op in op.prev: - if prev_op.inst == Instruction.wait: - remove_op(prev_op) - fused = True - break - if fused: - continue - queue = queue[1:] - - # rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di) - # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di]) - # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) - # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di) - # reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di) - def _optimize_rrc_r_signal_wait(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.read_reduce_copy: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.read_reduce_copy and same_count(op, next_op) and same_buf_dst(op, next_op) and same_chan_type(op, next_op): - op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.reduce: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.reduce and same_buf_dst(op, next_op) and same_chan_type(op, next_op): - op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.reduce_packet: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.reduce_packet and same_buf_dst(op, next_op) and same_chan_type(op, next_op): - op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.signal: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.signal and same_buf_src(op, next_op) and same_chan_type(op, next_op): - op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) - op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.wait: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.wait and same_buf_dst(op, next_op) and same_chan_type(op, next_op): - op.srcs.append((ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), next_op.step)) - op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size),next_op.step)) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - queue = queue[1:] - - # rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_) - # reduce(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rs(_,_,_,_,_,_) - def _optimize_rrcs_rs(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.read_reduce_copy or op.inst == Instruction.read_reduce_copy_send: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op) and same_chan_type(op, next_op): - if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: - continue - if op.inst == Instruction.read_reduce_copy: - op.inst = Instruction.read_reduce_copy_send - op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - if op.inst == Instruction.reduce or op.inst == Instruction.reduce_send: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.put and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm: - if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: - continue - if op.inst == Instruction.reduce: - op.inst = Instruction.reduce_send - op.channel_type = ChannelType.sm - op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - if op.inst == Instruction.reduce_packet or op.inst == Instruction.reduce_send_packet: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.put_packet and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm: - if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: - continue - if op.inst == Instruction.reduce_packet: - op.inst = Instruction.reduce_send_packet - op.channel_type = ChannelType.sm - op.dsts.append((ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), next_op.step)) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - queue = queue[1:] - - # get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di]) - # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di]) - def _optimize_get_put(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.get: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if seq_op.inst == Instruction.get and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op): - op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) - op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - if fused: - continue - elif op.inst == Instruction.put: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if seq_op.inst == Instruction.put and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op): - op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) - op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - if fused: - continue - elif op.inst == Instruction.put_packet: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if seq_op.inst == Instruction.put_packet and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op): - op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) - op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - queue = queue[1:] - - # For signal/wait ops, if they are independent of other operations and no other operations in between, - # then merge them into a single signal/wait op - # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) - def _parallel_signal_wait(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - if tbid == -1: - continue - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.signal: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if seq_op.inst == Instruction.signal and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op): - op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) - op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - if fused: - continue - elif op.inst == Instruction.wait: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if seq_op.inst == Instruction.wait and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op): - op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) - op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - if fused: - continue - queue = queue[1:] - - def optimize_mscclpp(self): - self._optimize_redundant_signal_wait() - self._optimize_rrc_r_signal_wait() - self._optimize_rrcs_rs() - self._optimize_get_put() - - self._parallel_signal_wait() - - # Completes metadata for chunk_steps (number of steps from a start op) and priority (number of steps to the last op) - def _complete_metadata(self): - def dfs(op, cs): - op.chunk_step = max(op.chunk_step, cs+1) - - if len(op.next) == 0 and op.recv_match is None: - op.priority = 0 - else: - for o in op.next: - dfs(o, op.chunk_step) - # Priority = +1 of the highest priority child - if len(op.next) > 0: - highest_next_priority = max([x.priority+1 for x in op.next]) - op.priority = max(highest_next_priority, op.priority) - if op.is_send(): - dfs(op.recv_match, op.chunk_step) - op.priority = max(op.priority, op.recv_match.priority+1) - - for chunk, op in self.operations.items(): - if op.inst == Instruction.start: - dfs(op,-2) # Start instructions should start at -1 - - # Given the set of operations that operate over a particular slot (rank, buffer, idx) fixed - # Try and replace operations with pipelined ops like receive copy send (rcs) - # or receive reduce send (rrs) and receive reduce copy send (rrcs) - # Rules: - # recv-copy-send - # recv(src, sbuf, si, _, _, _ ) send(_, _, _, dst, dbuf, di) -> recv_copy_send(src, sbuf, si, dst, dbuf, di) - def _optimize_rcs(self): - for slot, ops in self.operations.items(): - frontier = [ops] - while len(frontier) > 0: - op = frontier[0] - for next_op in op.next: - if op.inst == Instruction.recv and next_op.inst == Instruction.send and same_tb(op, next_op) and same_count(op, next_op) and same_buf_dst(op, next_op): - # recv -> rcs, remove send - op.inst = Instruction.recv_copy_send - op.dst = next_op.dst - next_op.recv_match.send_match = op - op.recv_match = next_op.recv_match - remove_op(next_op) - break - frontier = frontier[1:] + op.next - # recv-reduce-send - A rrc followed by a send that gets overwritten - # rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) recv(_, _, _, dst, dbuf, di) - # recv-reduce-copy-send - A rrc followed by a send that does not get overwritten - # rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) - def _optimize_rrcs_rrs(self): - # RRC/S -> RRS - for slot, ops in self.operations.items(): - frontier = [ops] - while len(frontier) > 0: - op = frontier[0] - if len(op.next) == 1: - next_op = op.next[0] - if len(next_op.next) == 1: - nnext_op = next_op.next[0] - if op.inst == Instruction.recv_reduce_copy and next_op.inst == Instruction.send and nnext_op.inst is Instruction.recv and same_tb(op, next_op) and same_count(op, next_op) and same_buf_dst(op, next_op): - op.inst = Instruction.recv_reduce_send - op.dst = next_op.dst - next_op.recv_match.send_match = op - op.recv_match = next_op.recv_match - remove_op(next_op) - - if op.inst == Instruction.recv_reduce_copy and next_op.inst == Instruction.send and same_tb(op, next_op) and same_count(op, next_op) and same_buf_dst(op, next_op): - op.inst = Instruction.recv_reduce_copy_send - op.dst = next_op.dst - next_op.recv_match.send_match = op - op.recv_match = next_op.recv_match - remove_op(next_op) - frontier = frontier[1:] + op.next - - def _get_tb_step(self, rank, tb): - if tb in self.tb_steps[rank]: - self.tb_steps[rank][tb] += 1 - return self.tb_steps[rank][tb] - else: - self.tb_steps[rank][tb] = 0 - return 0 - - def lower_pt1(self, instances): - self.infer_dependencies() - self.lower_buffers(instances) - - def lower_pt2(self, instances, interleaved): - self.replicate(instances, interleaved) - return self.lower_tbs() - - def lower_pt2_mscclpp(self, instances, instance_pollicy): - self.replicate_mscclpp(instances, instance_pollicy) - return self.lower_tbs() - - def infer_dependencies(self): - for slot, ops in self.operations.items(): - frontier = [ops] - while len(frontier) > 0: - op = frontier[0] - # Dependencies for every op is the same as the ops that are stored in prev - # Filter out dependencies that are satisified by tbs executing ops sequentially - # If multiple dependent ops from the same tb keep the one that happens last - depends = {} - for dep_op in list(op.prev): - if dep_op.inst != Instruction.start: - tb = dep_op.tb - if tb not in depends or dep_op.step > depends[tb].step: - depends[tb] = dep_op - op.depends = list(depends.values()) - frontier = frontier[1:] + op.next - - # Convert local scratch buffers to index into one global scratch buffer - def lower_chunk(self, chunk): - if chunk is not None and chunk.buffer is not Buffer.input and chunk.buffer is not Buffer.output: - buffer = self.buffers[chunk.rank][chunk.buffer].get_buffer() - index = self.buffers[chunk.rank][chunk.buffer].get_global_index(chunk.index) - return ChunkRef(chunk.rank, buffer, index, chunk.size) - return chunk - - # Assigns each scratch buffer an offset into the global scratch buffer - def lower_buffers(self, instances): - for rank_buffers in self.buffers: - offset = 0 - for key, buf in rank_buffers.items(): - if key is not Buffer.input and key is not Buffer.output: - buf.set_offset(offset) - offset += buf.instance_size() * instances - - # Preprocess the threadblocks for lowering into xml - def lower_tbs(self): - gpus = [] - for rank, rank_tbs in enumerate(self.instanced_tbs): - lowered_tbs = {} - for tbid, tb in rank_tbs.items(): - for op in tb.ops: - op.src = self.lower_chunk(op.src) - op.dst = self.lower_chunk(op.dst) - srcs = sorted(op.srcs, key=lambda x: x[1]) - dsts = sorted(op.dsts, key=lambda x: x[1]) - op.srcs = [self.lower_chunk(src[0]) for src in srcs] - op.dsts = [self.lower_chunk(dst[0]) for dst in dsts] - lowered_tbs[tbid] = tb - gpus.append(Gpu(rank, list(lowered_tbs.values()))) - return gpus - - # Automatically replicates the algorithm instance number of times - # interleaved sets the replication policy - # if True chunks are split as: ChunkA ChunkB -> ChunkA0 ChunkA1 .. ChunkB0 ChunkB1 ... - # if false chunks are divided as ChunkA0 ChunkB0 ChunkA1 ChunkB1 ... - # For collectives were chunks are designated for a particular GPU (e.g. AllToAll) - # only interleaved replication will be correct - # Interleaved policy only supports single count sends/receives from the input/output buffer - # (multicount ops are fine between scratch) - def replicate(self, instances, interleaved): - if instances == 1: - self.instanced_tbs = self.tbs - return - - self.instanced_tbs = [] - for _ in range(self.num_ranks): - self.instanced_tbs.append({}) - - def is_scratch(buffer): - return buffer != Buffer.input and buffer != Buffer.output - - def get_new_index(rank, buffer, index, size, i): - # Scratch buffers always use batched - if is_scratch(buffer): - buf_instance_len = self.buffers[rank][buffer].instance_size() - return buf_instance_len * i + index - # If this is operating on the input/output buffer then replication strategy can be either interleaved or batched - # This is to fit with the semantics of certain collectives - elif interleaved: - return index * instances + i * size - else: - return len(self.buffers[rank][buffer]) * i + index - - def get_instance_ref(ref): - iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i) - iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) - return iref - - max_channels = max(self.num_channels) - for i in range(instances): - # Generate all the threadblocks and ops - for rank, rank_tbs in enumerate(self.tbs): - # rank_channels = self.num_channels[rank] - for tbid, tb in rank_tbs.items(): - instance_channel = max_channels * i + tb.channel - itb = Threadblock(instance_channel, tb.send, tb.recv) - itbid = tbid * instances + i - itb.ops = [None] * len(tb.ops) - for s, op in enumerate(tb.ops): - isrc = get_instance_ref(op.src) - idst = get_instance_ref(op.dst) - idepends = [] - # Note: We don't need the fill out the rest of the metadata since replication is the last optimization - iop = Op(op.inst, op.rank, isrc, idst, idepends, op.step, itbid) - itb.ops[s] = iop - self.instanced_tbs[op.rank][itbid] = itb - - # Redo dependency analysis - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - for i in range(instances): - itbid = tbid * instances + i - itb = self.instanced_tbs[rank][itbid] - for op, iop in zip(tb.ops, itb.ops): - iop.depends = [None] * len(op.depends) - for s, dep in enumerate(op.depends): - dep_tbid = dep.tb - dep_itbid = dep_tbid * instances + i - dep_step = dep.step - iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] - - def replicate_mscclpp(self, instances, instance_policy): - # update op step - for rank, rank_tbs in enumerate(self.tbs): - for _, tb in rank_tbs.items(): - for id, op in enumerate(tb.ops): - op.step = id - - if instances == 1: - self.instanced_tbs = self.tbs - return - - self.instanced_tbs = [] - for _ in range(self.num_ranks): - self.instanced_tbs.append({}) - - def is_scratch(buffer): - return buffer != Buffer.input and buffer != Buffer.output - - def get_new_index(rank, buffer, index, size, i): - # Scratch buffers always use batched - if is_scratch(buffer): - buf_instance_len = self.buffers[rank][buffer].instance_size() - return buf_instance_len * i + index - return len(self.buffers[rank][buffer]) * i + index - - def get_instance_ref(ref): - iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i) - iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) - return iref - - if instance_policy == InstancePolicy.dup: - for i in range(instances): - # Generate all the threadblocks and ops - for rank, rank_tbs in enumerate(self.tbs): - # rank_channels = self.num_channels[rank] - for tbid, tb in rank_tbs.items(): - itbid = tbid * instances + i - itb = Threadblock(id=itbid) - itb.ops = [None] * len(tb.ops) - for s, op in enumerate(tb.ops): - isrc = get_instance_ref(op.src) - idst = get_instance_ref(op.dst) - idepends = [] - # Note: We don't need the fill out the rest of the metadata since replication is the last optimization - iop = Op(op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type) - itb.ops[s] = iop - for src, step in op.srcs: - isrc = get_instance_ref(src) - iop.srcs.append((isrc, step)) - for dst, step in op.dsts: - idst = get_instance_ref(dst) - iop.dsts.append((idst, step)) - for chan in tb.channels: - itb.channels.append(chan) - self.instanced_tbs[op.rank][itbid] = itb - - # Redo dependency analysis - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - for i in range(instances): - itbid = tbid * instances + i - itb = self.instanced_tbs[rank][itbid] - for op, iop in zip(tb.ops, itb.ops): - iop.depends = [None] * len(op.depends) - for s, dep in enumerate(op.depends): - dep_tbid = dep.tb - dep_itbid = dep_tbid * instances + i - dep_step = dep.step - iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] diff --git a/msccl/language/tb_assignment.py b/msccl/language/tb_assignment.py index 83d0171..f4fb429 100755 --- a/msccl/language/tb_assignment.py +++ b/msccl/language/tb_assignment.py @@ -6,7 +6,7 @@ import heapq from msccl.language.ir import * -from msccl.language.rank_dag import * +from msccl.language.instruction_dag import * def _verify_tb_op_compatible(tb, op): @@ -119,7 +119,7 @@ def priority(op): if op.inst == Instruction.start: visited.add(op) for o in op.next: - if o.inst == Instruction.send or o.inst == Instruction.copy: + if o.inst == MscclInstruction.send or o.inst == MscclInstruction.copy: heapq.heappush(ops, (priority(o), o)) while len(ops) > 0: @@ -206,7 +206,7 @@ def dfs(op, channels, f): # Assign channels to flows for op in instrs: - if op.inst == Instruction.send and op.recv_match.is_fused(): + if op.inst == MscclInstruction.send and op.recv_match.is_fused(): dfs(op, all_channels(), []) # Iterate through and make certain the sends and receives between a pair of GPUs is consistent @@ -233,5 +233,3 @@ def dfs(op, channels, f): op.channel += 1 op.send_match.channel += 1 pr.remove(op) - - diff --git a/msccl/language/types.py b/msccl/language/types.py new file mode 100644 index 0000000..3132bd3 --- /dev/null +++ b/msccl/language/types.py @@ -0,0 +1,240 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import dataclass, field +from enum import Enum +from typing import Union + +from msccl.language.buffer import Buffer +from msccl.language.channel import ChannelType + + +@dataclass +class Program: + name: str + collective: str + inplace: bool + protocol: str + gpus: list = field(default_factory=list) + num_chunk_groups: int = 1 + + +@dataclass +class Gpu: + rank: int + threadblocks: list = field(default_factory=list) + + # From ncclize + precopies: list = field(default_factory=list) + postcopies: list = field(default_factory=list) + inputs: dict = field(default_factory=dict) + outputs: dict = field(default_factory=dict) + input_chunks: int = 0 + output_chunks: int = 0 + scratch_chunks: int = 0 + scratch: dict = field(default_factory=dict) + channels: dict = field(default_factory=dict) + + def scratch_size(self): + return max((idx for addr, idx in self.scratch.items()), default=-1) + 1 + + +@dataclass +class Threadblock: + id: int = -1 + channel: int = -1 + send: int = -1 + recv: int = -1 + ops: list = field(default_factory=list) + rbid: int = -1 # threadblock id of the receiver + channels: list = field(default_factory=list) + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +class ChunkInstruction(Enum): + start = "start" + reduce = "reduce" + send = "send" + + def __str__(self): + return self.value + + +class ThreadblockPolicy(Enum): + auto = "auto" + manual = "manual" + + def __str__(self): + return self.value + + +class InstancePolicy(Enum): + # this means pack multi instrances to deal with the same chunk and share the channels + packed = "packed" + # this means each instance deal with the different chunk + # Chunk A, Chunk B -> Chunk A0, Chunk B0, Chunk A1, Chunk B1 + duplicated = "duplicated" + # this means each instance deal with the different chunk in interleaved way + # Chunk A, Chunk B -> Chunk A0, Chunk A1, Chunk B0, Chunk B1 + interleaved = "interleaved" + + def __str__(self): + return self.value + + +class Instruction(Enum): + delete = "d" + start = "st" + + def __str__(self): + return self.value + + +class MscclInstruction(Enum): + nop = "nop" + send = "s" + recv = "r" + recv_copy_send = "rcs" + recv_reduce_send = "rrs" + recv_reduce_copy = "rrc" + recv_reduce_copy_send = "rrcs" + copy = "cpy" + reduce = "re" + + def __str__(self): + return self.value + + +class MscclppInstruction(Enum): + nop = "nop" + read_reduce_copy = "rrc" + read_reduce_copy_send = "rrcs" + reduce_send = "rs" + copy = "copy" + reduce = "reduce" + copy_packet = "cpkt" + reduce_send_packet = "rspkt" + reduce_packet = "rpkt" + put = "put" + put_packet = "ppkt" + get = "get" + wait = "wait" + signal = "signal" + flush = "flush" + + def __str__(self): + return self.value + + +@dataclass +class ChunkRef: + rank: int + buffer: Buffer + index: int + size: int + + def __hash__(self): + return hash((self.rank, self.buffer, self.index, self.size)) + + +@dataclass +class Op: + inst: Union[Instruction, MscclInstruction, MscclppInstruction] + rank: int + src: ChunkRef + dst: ChunkRef + depends: list = field(default_factory=list) + step: int = -1 # Step in the TB + tb: int = -1 # TB this op is assigned to + prev: list = field(default_factory=list) # List of instructions that happen before + next: list = field(default_factory=list) # List of instructions that happen after + num: int = -1 + chunk_step: int = -1 + priority: int = -1 + recv_match = None + send_match = None + channel: int = -1 + channel_type: ChannelType = ChannelType.none + srcs: list = field(default_factory=list) + dsts: list = field(default_factory=list) + + def cnt(self): + if self.src: + if self.dst: + assert self.src.size == self.dst.size + return self.src.size + elif self.dst: + return self.dst.size + else: + return 0 + + def is_send(self): + return ( + self.inst == MscclInstruction.send + or self.inst == MscclInstruction.recv_reduce_copy_send + or self.inst == MscclInstruction.recv_copy_send + or self.inst == MscclInstruction.recv_reduce_send + ) + + def is_recv(self): + return ( + self.inst == MscclInstruction.recv + or self.inst == MscclInstruction.recv_reduce_copy + or self.inst == MscclInstruction.recv_reduce_copy_send + or self.inst == MscclInstruction.recv_copy_send + or self.inst == MscclInstruction.recv_reduce_send + ) + + def is_fused(self): + return ( + self.inst == MscclInstruction.recv_reduce_copy_send + or self.inst == MscclInstruction.recv_copy_send + or self.inst == MscclInstruction.recv_reduce_send + ) + + def is_local(self): + return self.inst == MscclInstruction.copy or self.inst == MscclInstruction.reduce + + def peer(self): + if self.inst == MscclInstruction.send: + return self.dst.rank + elif self.inst == MscclInstruction.recv: + return self.src.rank + else: + return None + + def send_peer(self): + if self.is_send(): + return self.dst.rank + return -1 + + def recv_peer(self): + if self.is_recv(): + return self.src.rank + return -1 + + def __eq__(self, other): + return self is other + + def __lt__(self, other): + # Ordering of operations + # 1. Lower chunk step 2. Higher priority 3. Lower src index + if self.chunk_step == other.chunk_step: + if self.priority == other.priority: + return self.src.index < other.src.index + return self.priority > other.priority + return self.chunk_step < other.chunk_step + + def __gt__(self, other): + return not self < other + + def __hash__(self): + return id(self) + + def __repr__(self): + return f"Op({self.inst}, {self.rank}, {self.src}, {self.dst}, step:{self.step}, tb:{self.tb})" diff --git a/msccl/language/visualize.py b/msccl/language/visualize.py index e24710a..d385c79 100755 --- a/msccl/language/visualize.py +++ b/msccl/language/visualize.py @@ -3,7 +3,7 @@ import igraph as ig from msccl.language.ir import * -from msccl.language.rank_dag import * +from msccl.language.instruction_dag import * def visualize_chunk_dag(chunk_paths): # pragma: no cover frontier = [] @@ -29,7 +29,7 @@ def add_node(op, nnodes, vertex_label, vertex_colors): return nnodes for chunk, op in chunk_paths.items(): - if len(op.prev) == 0: + if len(op.prev) == 0: frontier.append(op) while len(frontier) > 0: @@ -73,7 +73,7 @@ def add_node(op, nnodes, vertex_label, vertex_colors): else: vertex_label.append(f'{op.inst}') - # Add colors + # Add colors if op.inst == Instruction.start: vertex_colors.append('gray') else: @@ -81,7 +81,7 @@ def add_node(op, nnodes, vertex_label, vertex_colors): return nnodes for slot, op in operations.items(): - if len(op.prev) == 0: + if len(op.prev) == 0: frontier.append(op) while len(frontier) > 0: @@ -100,4 +100,4 @@ def add_node(op, nnodes, vertex_label, vertex_colors): g = ig.Graph(nnodes, edges, directed=True) layout = g.layout(layout=ig.Graph.layout_grid) - ig.plot(g, vertex_label=vertex_label, vertex_color=vertex_colors, layout='rt') \ No newline at end of file + ig.plot(g, vertex_label=vertex_label, vertex_color=vertex_colors, layout='rt') From eb4f612d9b58946488629a37caa8706ed00689b8 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 24 Apr 2024 11:33:19 +0000 Subject: [PATCH 44/76] fix --- .github/workflows/tests.yaml | 2 +- examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 3137151..d854490 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -11,7 +11,7 @@ jobs: strategy: matrix: - python-version: [3.8, 3.9, 3.10] + python-version: ['3.8', '3.9', '3.10'] name: Test with Python ${{ matrix.python-version }} diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index b052ec8..1fad562 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -37,7 +37,7 @@ def allreduce_allpairs(gpus, instances): c = chunk(r, Buffer.input, r * size + index) for peer in range(size): if peer != r: - c.reduce_packet(chunk(r, "scratch", peer * size + index), sendtb=index) + c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index) for peer in range(size): if peer != r: c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index) From 26172809c0eb46a5f30b1036b8e1c0ca856c1da3 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 24 Apr 2024 11:40:15 +0000 Subject: [PATCH 45/76] fix --- msccl/language/instruction_dag.py | 36 ++++++++++++++--------------- msccl/language/ir.py | 2 +- msccl/language/tb_assignment.py | 4 ++-- msccl/language/types.py | 38 +++++++++++++------------------ 4 files changed, 37 insertions(+), 43 deletions(-) diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py index f3e9748..fb4fe81 100755 --- a/msccl/language/instruction_dag.py +++ b/msccl/language/instruction_dag.py @@ -5,7 +5,7 @@ from collections import defaultdict from msccl.language.buffer import Buffer -from msccl.language.types import ChunkRef, Gpu, InstancePolicy, Instruction, MscclInstruction, Op, Threadblock +from msccl.language.types import ChunkRef, Gpu, InstancePolicy, Instruction, Op, Threadblock def remove_op(op: Op): @@ -186,7 +186,7 @@ def convert_set_list(self): op.next = list(op.next) for o in op.next: ops.append(o) - elif op.inst != MscclInstruction.copy: + elif op.inst != Instruction.copy: ops.append(op) while len(ops) > 0: @@ -216,14 +216,14 @@ def replicate(self, instances: int, instance_policy: InstancePolicy): pass -class MscclInstructionDAG(InstructionDAG): +class InstructionDAG(InstructionDAG): def __init__(self, num_ranks, buffers): super().__init__(num_ranks, buffers) # InstructionDAG - adds a copy node def add_copy(self, rank, send_ref, recv_ref, tb, ch): - op = Op(MscclInstruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer @@ -237,7 +237,7 @@ def add_copy(self, rank, send_ref, recv_ref, tb, ch): # InstructionDAG - adds a redduce node def add_reduce(self, rank, send_ref, recv_ref, tb, ch): - op = Op(MscclInstruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer @@ -251,7 +251,7 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, ch): # InstructionDAG - adds a send node def add_send(self, rank, send_ref, recv_ref, tb, ch): - op = Op(MscclInstruction.send, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + op = Op(Instruction.send, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) buffer = send_ref.buffer index = send_ref.index size = send_ref.size @@ -260,7 +260,7 @@ def add_send(self, rank, send_ref, recv_ref, tb, ch): # InstructionDAG - adds a recv node def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op): - op = Op(MscclInstruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + op = Op(Instruction.recv, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) buffer = recv_ref.buffer index = recv_ref.index size = recv_ref.size @@ -270,7 +270,7 @@ def add_recv(self, rank, send_ref, recv_ref, tb, ch, send_op): # InstructionDAG - adds a rrc node def add_recv_reduce_copy(self, rank, send_ref, recv_ref, tb, ch, send_op): - op = Op(MscclInstruction.recv_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) + op = Op(Instruction.recv_reduce_copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) buffer = recv_ref.buffer index = recv_ref.index size = recv_ref.size @@ -317,14 +317,14 @@ def _optimize_rcs(self): op = frontier[0] for next_op in op.next: if ( - op.inst == MscclInstruction.recv - and next_op.inst == MscclInstruction.send + op.inst == Instruction.recv + and next_op.inst == Instruction.send and same_tb(op, next_op) and same_count(op, next_op) and same_buf_dst(op, next_op) ): # recv -> rcs, remove send - op.inst = MscclInstruction.recv_copy_send + op.inst = Instruction.recv_copy_send op.dst = next_op.dst next_op.recv_match.send_match = op op.recv_match = next_op.recv_match @@ -347,27 +347,27 @@ def _optimize_rrcs_rrs(self): if len(next_op.next) == 1: nnext_op = next_op.next[0] if ( - op.inst == MscclInstruction.recv_reduce_copy - and next_op.inst == MscclInstruction.send - and nnext_op.inst is MscclInstruction.recv + op.inst == Instruction.recv_reduce_copy + and next_op.inst == Instruction.send + and nnext_op.inst is Instruction.recv and same_tb(op, next_op) and same_count(op, next_op) and same_buf_dst(op, next_op) ): - op.inst = MscclInstruction.recv_reduce_send + op.inst = Instruction.recv_reduce_send op.dst = next_op.dst next_op.recv_match.send_match = op op.recv_match = next_op.recv_match remove_op(next_op) if ( - op.inst == MscclInstruction.recv_reduce_copy - and next_op.inst == MscclInstruction.send + op.inst == Instruction.recv_reduce_copy + and next_op.inst == Instruction.send and same_tb(op, next_op) and same_count(op, next_op) and same_buf_dst(op, next_op) ): - op.inst = MscclInstruction.recv_reduce_copy_send + op.inst = Instruction.recv_reduce_copy_send op.dst = next_op.dst next_op.recv_match.send_match = op op.recv_match = next_op.recv_match diff --git a/msccl/language/ir.py b/msccl/language/ir.py index 3ae2ea3..aaee409 100755 --- a/msccl/language/ir.py +++ b/msccl/language/ir.py @@ -5,7 +5,7 @@ from collections import defaultdict from msccl.language.buffer import Buffer -from msccl.language.types import MscclInstruction as Instruction, Op, Program +from msccl.language.types import Instruction, Op, Program # Instructions where src is on local GPU diff --git a/msccl/language/tb_assignment.py b/msccl/language/tb_assignment.py index f4fb429..4d3c6ad 100755 --- a/msccl/language/tb_assignment.py +++ b/msccl/language/tb_assignment.py @@ -119,7 +119,7 @@ def priority(op): if op.inst == Instruction.start: visited.add(op) for o in op.next: - if o.inst == MscclInstruction.send or o.inst == MscclInstruction.copy: + if o.inst == Instruction.send or o.inst == Instruction.copy: heapq.heappush(ops, (priority(o), o)) while len(ops) > 0: @@ -206,7 +206,7 @@ def dfs(op, channels, f): # Assign channels to flows for op in instrs: - if op.inst == MscclInstruction.send and op.recv_match.is_fused(): + if op.inst == Instruction.send and op.recv_match.is_fused(): dfs(op, all_channels(), []) # Iterate through and make certain the sends and receives between a pair of GPUs is consistent diff --git a/msccl/language/types.py b/msccl/language/types.py index 3132bd3..55f2b1d 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -90,12 +90,6 @@ def __str__(self): class Instruction(Enum): delete = "d" start = "st" - - def __str__(self): - return self.value - - -class MscclInstruction(Enum): nop = "nop" send = "s" recv = "r" @@ -144,7 +138,7 @@ def __hash__(self): @dataclass class Op: - inst: Union[Instruction, MscclInstruction, MscclppInstruction] + inst: Union[Instruction, MscclppInstruction] rank: int src: ChunkRef dst: ChunkRef @@ -175,35 +169,35 @@ def cnt(self): def is_send(self): return ( - self.inst == MscclInstruction.send - or self.inst == MscclInstruction.recv_reduce_copy_send - or self.inst == MscclInstruction.recv_copy_send - or self.inst == MscclInstruction.recv_reduce_send + self.inst == Instruction.send + or self.inst == Instruction.recv_reduce_copy_send + or self.inst == Instruction.recv_copy_send + or self.inst == Instruction.recv_reduce_send ) def is_recv(self): return ( - self.inst == MscclInstruction.recv - or self.inst == MscclInstruction.recv_reduce_copy - or self.inst == MscclInstruction.recv_reduce_copy_send - or self.inst == MscclInstruction.recv_copy_send - or self.inst == MscclInstruction.recv_reduce_send + self.inst == Instruction.recv + or self.inst == Instruction.recv_reduce_copy + or self.inst == Instruction.recv_reduce_copy_send + or self.inst == Instruction.recv_copy_send + or self.inst == Instruction.recv_reduce_send ) def is_fused(self): return ( - self.inst == MscclInstruction.recv_reduce_copy_send - or self.inst == MscclInstruction.recv_copy_send - or self.inst == MscclInstruction.recv_reduce_send + self.inst == Instruction.recv_reduce_copy_send + or self.inst == Instruction.recv_copy_send + or self.inst == Instruction.recv_reduce_send ) def is_local(self): - return self.inst == MscclInstruction.copy or self.inst == MscclInstruction.reduce + return self.inst == Instruction.copy or self.inst == Instruction.reduce def peer(self): - if self.inst == MscclInstruction.send: + if self.inst == Instruction.send: return self.dst.rank - elif self.inst == MscclInstruction.recv: + elif self.inst == Instruction.recv: return self.src.rank else: return None From 66d04950f7a5c0e57bf48fe5c435ea323e9053bd Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 24 Apr 2024 11:47:41 +0000 Subject: [PATCH 46/76] fix --- msccl/language/msccl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/msccl/language/msccl.py b/msccl/language/msccl.py index e54756f..17f5d59 100644 --- a/msccl/language/msccl.py +++ b/msccl/language/msccl.py @@ -3,6 +3,7 @@ from msccl.language.buffer import * from msccl.language.instruction_dag import * +from msccl.language.mscclpp.instruction_dag import MscclppInstructionDAG from msccl.language.passes import * from msccl.language.tb_assignment import * from msccl.language.types import ThreadblockPolicy @@ -49,7 +50,7 @@ def __init__( # Initialize the input buffers # self.chunk_dag = ChunkDAG() self.buffers = collective.init_buffers() - self.instr_dag = MscclInstructionDAG(self.num_ranks, self.buffers) + self.instr_dag = MscclppInstructionDAG(self.num_ranks, self.buffers) for r in range(self.num_ranks): for index, chunk in enumerate(self.buffers[r][Buffer.input]): buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) From 8cfcc30e62632a87f1d50129ab62d0e3cce74fd1 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 25 Apr 2024 06:08:49 +0000 Subject: [PATCH 47/76] Fix UT --- msccl/language/__init__.py | 266 ++++++++++++++++++++++++++++- msccl/language/instruction_dag.py | 2 +- msccl/language/msccl.py | 268 ------------------------------ 3 files changed, 261 insertions(+), 275 deletions(-) delete mode 100644 msccl/language/msccl.py diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 7f0822a..9ba80a0 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -7,28 +7,282 @@ from msccl.language.chunk import * from msccl.language.buffer import * from msccl.language.instruction_dag import * -import msccl.language.msccl as msccl_lang import msccl.language.mscclpp as mscclpp from msccl.language.mscclpp import * -from msccl.language.msccl import * from typing import Union +from msccl.language.types import ThreadblockPolicy + # from msccl.language.visualize import * +_current_program = None + + def _curr(): - if msccl_lang._current_program == None and mscclpp._current_program == None: + global _current_program + if _current_program == None and mscclpp._current_program == None: raise RuntimeError("No Program in context") - if msccl_lang._current_program == None: + if _current_program == None: return mscclpp._current_program - return msccl_lang._current_program + return _current_program + + +class MSCCLProgram: + def __init__( + self, + name, + topo, + collective, + instances, + protocol="Simple", + threadblock_policy=ThreadblockPolicy.auto, + interleaved_replication=True, + instr_fusion=True, + check_xml=True, + dependence_nop=False, + ): + self.name = name + self.topo = topo + self.collective = collective + self.num_ranks = topo.num_nodes() + self.instances = instances + self.protocol = protocol + self.threadblock_policy = threadblock_policy + self.interleaved_replication = interleaved_replication + self.instr_fusion = instr_fusion + self.check_xml = check_xml + self.dependence_nop = dependence_nop + assert ( + protocol == "Simple" or protocol == "LL" or protocol == "LL128" + ), f"Given protocol: {protocol}. Must be either Simple, LL, LL128" + self.run_opt = True # Runs optimization passes + # Initialize the input buffers + # self.chunk_dag = ChunkDAG() + self.buffers = collective.init_buffers() + self.instr_dag = MscclInstructionDAG(self.num_ranks, self.buffers) + for r in range(self.num_ranks): + for index, chunk in enumerate(self.buffers[r][Buffer.input]): + buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) + ref = self.get_ref(r, buffer, index, 1) + # self.chunk_dag.init_chunk(chunk, ref) + self.instr_dag.add_start(r, buffer, index, ref) + + def __enter__(self): + global _current_program + if _current_program != None: + raise RuntimeError("There is already a MSCCL Program in context") + _current_program = self + + def __exit__(self, exc_type, exc_value, exc_traceback): + global _current_program + if _current_program != self: + raise RuntimeError("This program is not currently in context") + _current_program = None + + # Tracks a send operation on the buffers + def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + db[dst_index + i] = sb[src_index + i] + + # Tracks a reduce operation on the buffers + def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + reduce_chunk = db[dst_index + i] + sent_chunk = sb[src_index + i] + db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) + + def get_ref(self, rank, buffer, index, size): + buffer, index = self.collective.get_buffer_index(rank, buffer, index) + return Ref(rank, buffer, index, size, self) + + def get_chunks(self, rank, buffer, index, size=1): + chunks = [None] * size + for i in range(0, size): + if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]): + chunks[i] = self.buffers[rank][buffer][index + i] + else: + chunks[i] = None + return chunks + + def check_buffer_exists(self, rank, name): + if name not in self.buffers[rank]: + self.buffers[rank][name] = BufferSlice(Buffer.scratch, name) + + # Checks that all chunks that should be on each rank + # are present in the output buffer. + def check(self): + return self.collective.check(self) + + # Lower program to XML + def lower(self): + # self.chunk_dag._complete_metadata() + # self.chunk_dag.channel_assignment() + # self.chunk_dag.lower_instr_dag(self.instr_dag) + self.instr_dag.convert_set_list() # Pre-emptively convert sets to lists + if self.instr_fusion: + self.instr_dag.optimize() + self.instr_dag._complete_metadata() + if self.threadblock_policy == ThreadblockPolicy.manual: + manual_assign_tbs(self.instr_dag) + else: + auto_assign_tbs(self.instr_dag) + self.instr_dag.lower_pt1(self.instances) + gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) + if self.check_xml: + # Check generated MSCCL-IR for correctness - no circular dependencies, sends and receives are ordered + # For very large programs, turn off check_xml when shipping + check_dependency_cycles(self.instr_dag.tbs) + check_threadblock_ordering(self.instr_dag) + return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) + + def generate_xml(self): + return ir_to_xml(self.lower(), dependence_nop=self.dependence_nop) + + def print_chunk_dag(self): + visualize_chunk_dag(self.chunk_dag.chunk_paths) + + def print_instr_dags(self, rank): + if rank == 0: + for r in range(len(self.ranks)): + visualize_instr_dag(self.instr_dags[r].operations) + else: + visualize_instr_dag(self.instr_dags[rank].operations) + + +def XML(): + print(_curr().generate_xml()) + + +@dataclass +class Ref(ChunkRef): + prog: MSCCLProgram + + def __repr__(self): + return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})" + + def _end(self): + return self.index + self.size + + def _get_chunk(self, index): + return self.prog.buffers[self.rank][self.buffer][index] + + def split(self, num): + assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts" + chunks = [None] * num + size = self.size // num + for i in range(num): + index = self.index + i * size + chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size) + return chunks + + def group(self, other): + assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}" + assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}" + if self.index < other.index: + first = self + second = other + else: + first = other + second = self + + end = max(first._end(), second._end()) + return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) + + # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) + def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): + self.prog.check_buffer_exists(dst, buffer) + + # If index is not specified assume it is going to the same place in the next gpu + if index == -1 and buffer == None: + index = self.index + buffer = self.buffer + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + index = self.prog.buffers[dst][buffer].instance_size() + + # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) + buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) + + # Direct send + assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) + if dst_chunkref == self: + return + + # chunks = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) + # overwritten_chunks = self.prog.get_chunks(dst, buffer, index, self.size) + + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + + # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) + sender = self.rank + receiver = dst + if sender != receiver: + sop = self.prog.instr_dag.add_send(sender, self, dst_chunkref, sendtb, ch) + rop = self.prog.instr_dag.add_recv(receiver, self, dst_chunkref, recvtb, ch, sop) + sop.recv_match = rop + else: + self.prog.instr_dag.add_copy(sender, self, dst_chunkref, sendtb, ch) + + return dst_chunkref + + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): + # Receive reduce copy + dst = self.rank + src = other_chunkref.rank + assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}" + # dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + + # chunks1 = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) + # chunks2 = self.prog.get_chunks(other_chunkref.rank, other_chunkref.buffer, other_chunkref.index self.size) + + self.prog.apply_reduce( + src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size + ) + + # reduce_chunks = self.prog.get_chunks(dst, buffer, index, self.size) + # self.prog.chunk_dag.add_reduce(chunks1, chunks2, reduce_chunks, self, dst_chunkref, sendtb, recvtb, ch) + if src != dst: + sop = self.prog.instr_dag.add_send(src, other_chunkref, self, sendtb, ch) + rop = self.prog.instr_dag.add_recv_reduce_copy(dst, other_chunkref, self, recvtb, ch, sop) + sop.recv_match = rop + else: + self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ch) + + return self + + def get_origin_index(self, index=0): + return self._get_chunk(index + self.index).origin_index + + def get_origin_rank(self, index=0): + return self._get_chunk(index + self.index).origin_rank + + def get_dst_index(self, index=0): + return self._get_chunk(index + self.index).dst_index + + def get_dst_rank(self, index=0): + return self._get_chunk(index + self.index).dst_rank + + def print_chunk_info(self, index=0): + print(self._get_chunk(index + self.index)) def Print(): _curr().print_chunk_dag() -def chunk(rank, buffer, index, size=1) -> Union[mscclpp.Ref, msccl_lang.Ref]: +def chunk(rank, buffer, index, size=1) -> Union[mscclpp.Ref, Ref]: if _curr().buffers[rank][buffer][index] is None: return None return _curr().get_ref(rank, buffer, index, size) diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py index fb4fe81..2e74ef0 100755 --- a/msccl/language/instruction_dag.py +++ b/msccl/language/instruction_dag.py @@ -216,7 +216,7 @@ def replicate(self, instances: int, instance_policy: InstancePolicy): pass -class InstructionDAG(InstructionDAG): +class MscclInstructionDAG(InstructionDAG): def __init__(self, num_ranks, buffers): super().__init__(num_ranks, buffers) diff --git a/msccl/language/msccl.py b/msccl/language/msccl.py deleted file mode 100644 index 17f5d59..0000000 --- a/msccl/language/msccl.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from msccl.language.buffer import * -from msccl.language.instruction_dag import * -from msccl.language.mscclpp.instruction_dag import MscclppInstructionDAG -from msccl.language.passes import * -from msccl.language.tb_assignment import * -from msccl.language.types import ThreadblockPolicy - -_current_program = None - - -def _curr(): - global _current_program - if _current_program == None: - raise RuntimeError("No Program in context") - return _current_program - - -class MSCCLProgram: - def __init__( - self, - name, - topo, - collective, - instances, - protocol="Simple", - threadblock_policy=ThreadblockPolicy.auto, - interleaved_replication=True, - instr_fusion=True, - check_xml=True, - dependence_nop=False, - ): - self.name = name - self.topo = topo - self.collective = collective - self.num_ranks = topo.num_nodes() - self.instances = instances - self.protocol = protocol - self.threadblock_policy = threadblock_policy - self.interleaved_replication = interleaved_replication - self.instr_fusion = instr_fusion - self.check_xml = check_xml - self.dependence_nop = dependence_nop - assert ( - protocol == "Simple" or protocol == "LL" or protocol == "LL128" - ), f"Given protocol: {protocol}. Must be either Simple, LL, LL128" - self.run_opt = True # Runs optimization passes - # Initialize the input buffers - # self.chunk_dag = ChunkDAG() - self.buffers = collective.init_buffers() - self.instr_dag = MscclppInstructionDAG(self.num_ranks, self.buffers) - for r in range(self.num_ranks): - for index, chunk in enumerate(self.buffers[r][Buffer.input]): - buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) - ref = self.get_ref(r, buffer, index, 1) - # self.chunk_dag.init_chunk(chunk, ref) - self.instr_dag.add_start(r, buffer, index, ref) - - def __enter__(self): - global _current_program - if _current_program != None: - raise RuntimeError("There is already a MSCCL Program in context") - _current_program = self - - def __exit__(self, exc_type, exc_value, exc_traceback): - global _current_program - if _current_program != self: - raise RuntimeError("This program is not currently in context") - _current_program = None - - # Tracks a send operation on the buffers - def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): - src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) - dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) - sb = self.buffers[src][src_buffer] - db = self.buffers[dst][dst_buffer] - for i in range(size): - db[dst_index + i] = sb[src_index + i] - - # Tracks a reduce operation on the buffers - def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): - src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) - dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) - sb = self.buffers[src][src_buffer] - db = self.buffers[dst][dst_buffer] - for i in range(size): - reduce_chunk = db[dst_index + i] - sent_chunk = sb[src_index + i] - db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) - - def get_ref(self, rank, buffer, index, size): - buffer, index = self.collective.get_buffer_index(rank, buffer, index) - return Ref(rank, buffer, index, size, self) - - def get_chunks(self, rank, buffer, index, size=1): - chunks = [None] * size - for i in range(0, size): - if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]): - chunks[i] = self.buffers[rank][buffer][index + i] - else: - chunks[i] = None - return chunks - - def check_buffer_exists(self, rank, name): - if name not in self.buffers[rank]: - self.buffers[rank][name] = BufferSlice(Buffer.scratch, name) - - # Checks that all chunks that should be on each rank - # are present in the output buffer. - def check(self): - return self.collective.check(self) - - # Lower program to XML - def lower(self): - # self.chunk_dag._complete_metadata() - # self.chunk_dag.channel_assignment() - # self.chunk_dag.lower_instr_dag(self.instr_dag) - self.instr_dag.convert_set_list() # Pre-emptively convert sets to lists - if self.instr_fusion: - self.instr_dag.optimize() - self.instr_dag._complete_metadata() - if self.threadblock_policy == ThreadblockPolicy.manual: - manual_assign_tbs(self.instr_dag) - else: - auto_assign_tbs(self.instr_dag) - self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) - if self.check_xml: - # Check generated MSCCL-IR for correctness - no circular dependencies, sends and receives are ordered - # For very large programs, turn off check_xml when shipping - check_dependency_cycles(self.instr_dag.tbs) - check_threadblock_ordering(self.instr_dag) - return Program(self.name, self.collective.name, self.collective.inplace, self.protocol, gpu_prgms) - - def generate_xml(self): - return ir_to_xml(self.lower(), dependence_nop=self.dependence_nop) - - def print_chunk_dag(self): - visualize_chunk_dag(self.chunk_dag.chunk_paths) - - def print_instr_dags(self, rank): - if rank == 0: - for r in range(len(self.ranks)): - visualize_instr_dag(self.instr_dags[r].operations) - else: - visualize_instr_dag(self.instr_dags[rank].operations) - - -def XML(): - print(_curr().generate_xml()) - - -@dataclass -class Ref(ChunkRef): - prog: MSCCLProgram - - def __repr__(self): - return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})" - - def _end(self): - return self.index + self.size - - def _get_chunk(self, index): - return self.prog.buffers[self.rank][self.buffer][index] - - def split(self, num): - assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts" - chunks = [None] * num - size = self.size // num - for i in range(num): - index = self.index + i * size - chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size) - return chunks - - def group(self, other): - assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}" - assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}" - if self.index < other.index: - first = self - second = other - else: - first = other - second = self - - end = max(first._end(), second._end()) - return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) - - # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) - def copy(self, dst, buffer=None, index=-1, sendtb=-1, recvtb=-1, ch=-1): - self.prog.check_buffer_exists(dst, buffer) - - # If index is not specified assume it is going to the same place in the next gpu - if index == -1 and buffer == None: - index = self.index - buffer = self.buffer - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - index = self.prog.buffers[dst][buffer].instance_size() - - # Some inplace collectives have custom logic for buffers and index (ReduceScatter, AllGather) - buffer, index = self.prog.collective.get_buffer_index(self.rank, buffer, index) - - # Direct send - assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - - # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) - if dst_chunkref == self: - return - - # chunks = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) - # overwritten_chunks = self.prog.get_chunks(dst, buffer, index, self.size) - - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - - # self.prog.chunk_dag.add_send(chunks, overwritten_chunks, self, dst_chunkref, sendtb, recvtb, ch) - sender = self.rank - receiver = dst - if sender != receiver: - sop = self.prog.instr_dag.add_send(sender, self, dst_chunkref, sendtb, ch) - rop = self.prog.instr_dag.add_recv(receiver, self, dst_chunkref, recvtb, ch, sop) - sop.recv_match = rop - else: - self.prog.instr_dag.add_copy(sender, self, dst_chunkref, sendtb, ch) - - return dst_chunkref - - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce(self, other_chunkref, sendtb=-1, recvtb=-1, ch=-1): - # Receive reduce copy - dst = self.rank - src = other_chunkref.rank - assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}" - # dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - - # chunks1 = self.prog.get_chunks(self.rank, self.buffer, self.index, self.size) - # chunks2 = self.prog.get_chunks(other_chunkref.rank, other_chunkref.buffer, other_chunkref.index self.size) - - self.prog.apply_reduce( - src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size - ) - - # reduce_chunks = self.prog.get_chunks(dst, buffer, index, self.size) - # self.prog.chunk_dag.add_reduce(chunks1, chunks2, reduce_chunks, self, dst_chunkref, sendtb, recvtb, ch) - if src != dst: - sop = self.prog.instr_dag.add_send(src, other_chunkref, self, sendtb, ch) - rop = self.prog.instr_dag.add_recv_reduce_copy(dst, other_chunkref, self, recvtb, ch, sop) - sop.recv_match = rop - else: - self.prog.instr_dag.add_reduce(src, other_chunkref, self, sendtb, ch) - - return self - - def get_origin_index(self, index=0): - return self._get_chunk(index + self.index).origin_index - - def get_origin_rank(self, index=0): - return self._get_chunk(index + self.index).origin_rank - - def get_dst_index(self, index=0): - return self._get_chunk(index + self.index).dst_index - - def get_dst_rank(self, index=0): - return self._get_chunk(index + self.index).dst_rank - - def print_chunk_info(self, index=0): - print(self._get_chunk(index + self.index)) From 2ba1760011961c7dd9fb8dc0d95a04448821dc75 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 25 Apr 2024 08:28:23 +0000 Subject: [PATCH 48/76] update --- .../allreduce_a100_allpairs_packet_mscclpp.py | 2 +- .../allreduce_a100_allpairs_sm_mscclpp.py | 1 + .../allreduce_a100_allpairs_sm_mscclpp_get.py | 1 + msccl/language/__init__.py | 8 +- msccl/language/instruction_dag.py | 12 +-- msccl/language/mscclpp/__init__.py | 9 ++- msccl/language/mscclpp/instruction_dag.py | 6 +- msccl/language/mscclpp/ir.py | 7 ++ msccl/language/types.py | 2 +- tests/test_language.py | 75 +++++++++++++++++-- 10 files changed, 98 insertions(+), 25 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index 1fad562..d9c6275 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -50,7 +50,7 @@ def allreduce_allpairs(gpus, instances): c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) Json() - # Check() + Check() parser = argparse.ArgumentParser() diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index 74ae223..08717c4 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -45,6 +45,7 @@ def allreduce_allpairs(gpus, instances, protocol): c_peer.wait(nghr, Buffer.input, peer_index + tb, recvtb=tb) Json() + Check() parser = argparse.ArgumentParser() diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py index 49f3606..5407676 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -59,6 +59,7 @@ def allreduce_allpairs(gpus, instances, protocol): c.get(nghr, Buffer.input, index + tb, recvtb=tb) Json() + Check() parser = argparse.ArgumentParser() diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 9ba80a0..51cd330 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -11,7 +11,7 @@ from msccl.language.mscclpp import * from typing import Union -from msccl.language.types import ThreadblockPolicy +from msccl.language.types import ReplicationPolicy, ThreadblockPolicy # from msccl.language.visualize import * @@ -49,7 +49,9 @@ def __init__( self.instances = instances self.protocol = protocol self.threadblock_policy = threadblock_policy - self.interleaved_replication = interleaved_replication + self.replication_policy = ( + ReplicationPolicy.interleaved if interleaved_replication else ReplicationPolicy.duplicated + ) self.instr_fusion = instr_fusion self.check_xml = check_xml self.dependence_nop = dependence_nop @@ -136,7 +138,7 @@ def lower(self): else: auto_assign_tbs(self.instr_dag) self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) + gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy) if self.check_xml: # Check generated MSCCL-IR for correctness - no circular dependencies, sends and receives are ordered # For very large programs, turn off check_xml when shipping diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py index 2e74ef0..7bd70bf 100755 --- a/msccl/language/instruction_dag.py +++ b/msccl/language/instruction_dag.py @@ -5,7 +5,7 @@ from collections import defaultdict from msccl.language.buffer import Buffer -from msccl.language.types import ChunkRef, Gpu, InstancePolicy, Instruction, Op, Threadblock +from msccl.language.types import ChunkRef, Gpu, Instruction, Op, ReplicationPolicy, Threadblock def remove_op(op: Op): @@ -203,8 +203,8 @@ def lower_pt1(self, instances: int): self._infer_dependencies() self._lower_buffers(instances) - def lower_pt2(self, instances: int, instance_pollicy: InstancePolicy): - self.replicate(instances, instance_pollicy) + def lower_pt2(self, instances: int, replication_policy: ReplicationPolicy): + self.replicate(instances, replication_policy) return self._lower_tbs() @abstractmethod @@ -212,7 +212,7 @@ def optimize(self): pass @abstractmethod - def replicate(self, instances: int, instance_policy: InstancePolicy): + def replicate(self, instances: int, replication_policy: ReplicationPolicy): pass @@ -382,7 +382,7 @@ def _optimize_rrcs_rrs(self): # only interleaved replication will be correct # Interleaved policy only supports single count sends/receives from the input/output buffer # (multicount ops are fine between scratch) - def replicate(self, instances, interleaved): + def replicate(self, instances, replication_policy: ReplicationPolicy): if instances == 1: self.instanced_tbs = self.tbs return @@ -401,7 +401,7 @@ def get_new_index(rank, buffer, index, size, i): return buf_instance_len * i + index # If this is operating on the input/output buffer then replication strategy can be either interleaved or batched # This is to fit with the semantics of certain collectives - elif interleaved: + elif replication_policy == ReplicationPolicy.interleaved: return index * instances + i * size else: return len(self.buffers[rank][buffer]) * i + index diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py index 5e8e29f..9bc27af 100644 --- a/msccl/language/mscclpp/__init__.py +++ b/msccl/language/mscclpp/__init__.py @@ -31,7 +31,7 @@ def __init__( instances: int, protocol: str = "Simple", instr_fusion: bool = True, - instance_policy: InstancePolicy = InstancePolicy.duplicated, + replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated, ): self.name = name self.topo = topo @@ -40,7 +40,7 @@ def __init__( self.instances = instances self.protocol = protocol self.instr_fusion = instr_fusion - self.instance_policy = instance_policy + self.replication_policy = replication_policy assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL" self.run_opt = True # Runs optimization passes # Initialize the input buffers @@ -114,7 +114,7 @@ def lower(self): if self.instr_fusion: self.instr_dag.optimize() self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.instance_policy) + gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy) return Program( self.name, self.collective.name, @@ -189,9 +189,10 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none) else: self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type) + return dst_chunkref def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): - self._put(dst, buffer, index, sendtb, chan_type) + return self._put(dst, buffer, index, sendtb, chan_type) def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True) diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index 093d41d..e0bf47a 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -15,7 +15,7 @@ same_src_dst_buffer_type, ) from msccl.language.instruction_dag import InstructionDAG -from msccl.language.types import ChunkRef, InstancePolicy, MscclppInstruction as Instruction, Op, Threadblock +from msccl.language.types import ChunkRef, MscclppInstruction as Instruction, Op, ReplicationPolicy, Threadblock class MscclppInstructionDAG(InstructionDAG): @@ -631,7 +631,7 @@ def optimize(self): self._parallel_signal_wait() - def replicate(self, instances: int, instance_policy: InstancePolicy): + def replicate(self, instances: int, replication_policy: ReplicationPolicy): # update op step for rank, rank_tbs in enumerate(self.tbs): for _, tb in rank_tbs.items(): @@ -661,7 +661,7 @@ def get_instance_ref(ref): iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) return iref - if instance_policy == InstancePolicy.duplicated: + if replication_policy == ReplicationPolicy.duplicated: for i in range(instances): # Generate all the threadblocks and ops for rank, rank_tbs in enumerate(self.tbs): diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index 28df6f5..2c8f37e 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -61,6 +61,13 @@ def ir_to_json(program: Program): gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) + # Since LL protocol will double the scratch size. We need to make sure all GPUs have the same scratch size. + # Otherwise the offset calculation will be wrong. + if program.protocol == "LL": + max_scratch = max(gpu.scratch_chunks for gpu in program.gpus) + for gpu in program.gpus: + gpu.scratch_chunks = max_scratch + # get channel info for each GPU and threadblock for gpu in program.gpus: gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id) diff --git a/msccl/language/types.py b/msccl/language/types.py index 55f2b1d..97a4277 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -73,7 +73,7 @@ def __str__(self): return self.value -class InstancePolicy(Enum): +class ReplicationPolicy(Enum): # this means pack multi instrances to deal with the same chunk and share the channels packed = "packed" # this means each instance deal with the different chunk diff --git a/tests/test_language.py b/tests/test_language.py index 9c2a38d..6fc9b98 100755 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import msccl +from msccl.language.types import MscclppInstruction from msccl.topologies import line, fully_connected from msccl.language import * from msccl.language.routines import * @@ -16,15 +17,15 @@ def init_buffers(self): rank_buffers = [] for r in range(self.num_ranks): input_buffer = [None] * chunks_per_node - output_buffer = [None] * chunks_per_node + output_buffer = [None] * chunks_per_node if r == 0: for c in range(chunks_per_node): input_buffer[c] = Chunk(r, c, 2, c) - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : output_buffer} rank_buffers.append(buffers) return rank_buffers - + # Final state chunk0 from rank0 is in the output buffer of rank2 def check(self, prog): @@ -46,14 +47,14 @@ def init_buffers(self): rank_buffers = [] for r in range(self.num_ranks): input_buffer = [None] * chunks_per_node - output_buffer = [None] * chunks_per_node + output_buffer = [None] * chunks_per_node for c in range(chunks_per_node): input_buffer[c] = Chunk(r, c, -1, c) - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : output_buffer} rank_buffers.append(buffers) return rank_buffers - + # Final state rank2 has a fully reduced chunk from gpus 0, 1, and 2 def check(self, prog): @@ -201,6 +202,32 @@ def test_instruction_fusion(): assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == Instruction.recv assert lowered_prgm.gpus[2].threadblocks[0].ops[0].inst == Instruction.recv_reduce_copy_send + +def test_instruction_fusion_mscclpp(): + topology = fully_connected(3) + collective = AllReduce(3, 3, True) + prgm = MSCCLPPProgram("allreduce", topology, collective, 1) + with prgm: + c01 = chunk(1, Buffer.input, 0, 3).reduce(chunk(0, Buffer.input, 0, 3), recvtb=0) + c01.signal(2, Buffer.input, 0, sendtb=0) + c012 = chunk(2, Buffer.input, 0, 3) + c012.wait(1, Buffer.input, 0, recvtb=0) + c012.reduce(c01, recvtb=0).put(0, Buffer.input, 0, sendtb=0) + c012.signal(0, Buffer.input, 0, sendtb=0) + c0 = chunk(0, Buffer.input, 0, 3) + c0.wait(2, Buffer.input, 0, recvtb=0) + c0.put(1, Buffer.input, 0, sendtb=0) + assert Check() + lowered_prgm = prgm.lower() + assert lowered_prgm.gpus[0].threadblocks[0].ops[0].inst == MscclppInstruction.wait + assert lowered_prgm.gpus[0].threadblocks[0].ops[1].inst == MscclppInstruction.put + assert lowered_prgm.gpus[1].threadblocks[0].ops[0].inst == MscclppInstruction.read_reduce_copy + assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == MscclppInstruction.signal + assert lowered_prgm.gpus[2].threadblocks[0].ops[0].inst == MscclppInstruction.wait + assert lowered_prgm.gpus[2].threadblocks[0].ops[1].inst == MscclppInstruction.read_reduce_copy_send + assert lowered_prgm.gpus[2].threadblocks[0].ops[2].inst == MscclppInstruction.signal + + def test_replication(): topology = fully_connected(2) collective = AllToAll(2, 1, False) @@ -287,4 +314,38 @@ def test_routines_allreduce_nodes(): c.reduce(chunk(r, Buffer.output, exchange_index, 4)) c = c.copy(r, Buffer.output, exchange_index) XML() - assert Check() \ No newline at end of file + assert Check() + +def test_routines_allreduce_packet_inplace_mscclpp(): + size = 8 + topology = fully_connected(size) + collective = AllReduce(size, size * size, True) + with MSCCLPPProgram("allreduce_packet", topology, collective, 1, protocol="LL"): + # Each rank sends the nth chunk to the nth rank into scratch space + for r1 in range(size): + for tb in range(size): + if tb == r1: + continue + remote_rank = tb + index = remote_rank * size + c = chunk(r1, Buffer.input, index, size) + c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb) + # Each rank performs a local reduction on the nth chunk + # Utilize 8 threadblocks for this reduction for better parallelism + for r in range(size): + for index in range(size): + c = chunk(r, Buffer.input, r * size + index) + for peer in range(size): + if peer != r: + c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index) + for peer in range(size): + if peer != r: + c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index) + # Each rank get final result from scratch space + for r in range(size): + for peer in range(size): + if peer != r: + c = chunk(r, "scratch", size * size + peer * size, size) + c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) + Json() + assert Check() From 52783e93621a53c860925980de6f365ace35218c Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 25 Apr 2024 08:32:20 +0000 Subject: [PATCH 49/76] WIP --- tests/test_language.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_language.py b/tests/test_language.py index 6fc9b98..bcd22ab 100755 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -320,7 +320,7 @@ def test_routines_allreduce_packet_inplace_mscclpp(): size = 8 topology = fully_connected(size) collective = AllReduce(size, size * size, True) - with MSCCLPPProgram("allreduce_packet", topology, collective, 1, protocol="LL"): + with MSCCLPPProgram("allreduce_packet", topology, collective, 2, protocol="LL"): # Each rank sends the nth chunk to the nth rank into scratch space for r1 in range(size): for tb in range(size): From 01a0745a28502ba1169e9120136ec1a2dbf5d3a6 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 25 Apr 2024 08:35:56 +0000 Subject: [PATCH 50/76] WIP --- .github/workflows/tests.yaml | 2 +- pytest.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index d854490..f3b6c8c 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -25,6 +25,6 @@ jobs: run: | pip install --upgrade pip pip install -r requirements.txt - - name: Run tests and check at least 90% coverage + - name: Run tests and check at least 85% coverage run: | pytest diff --git a/pytest.ini b/pytest.ini index d68bf05..4621e92 100755 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,2 @@ [pytest] -addopts = --cov=msccl --cov-report term-missing:skip-covered --cov-fail-under 90 -n auto +addopts = --cov=msccl --cov-report term-missing:skip-covered --cov-fail-under 85 -n auto From 13e902ab37958667352abd7f45e8dd08cf05f5cb Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 25 Apr 2024 08:42:07 +0000 Subject: [PATCH 51/76] update --- examples/mscclang/allreduce_a100_ring_mscclpp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/mscclang/allreduce_a100_ring_mscclpp.py b/examples/mscclang/allreduce_a100_ring_mscclpp.py index 8801ac5..712f830 100644 --- a/examples/mscclang/allreduce_a100_ring_mscclpp.py +++ b/examples/mscclang/allreduce_a100_ring_mscclpp.py @@ -43,6 +43,7 @@ def allreduce_ring(size, instances): c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0) Json() + Check() parser = argparse.ArgumentParser() From e52cabf71e3a079ea9b41aca2bf4855e83cb1304 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 03:35:51 +0000 Subject: [PATCH 52/76] revert --- .../allreduce_a100_allpairs_packet_mscclpp.py | 62 -- .../allreduce_a100_allpairs_sm_mscclpp.py | 58 -- .../allreduce_a100_allpairs_sm_mscclpp_get.py | 72 -- .../mscclang/allreduce_a100_ring_mscclpp.py | 54 -- .../mscclang/allreduce_mi300_sm_mscclpp.py | 69 ++ .../mscclang/allreduce_mi300_sm_mscclpp2.py | 69 ++ msccl/language/__init__.py | 9 +- msccl/language/channel.py | 25 - msccl/language/collectives.py | 3 +- msccl/language/mscclpp/__init__.py | 299 -------- msccl/language/mscclpp/instruction_dag.py | 704 ------------------ msccl/language/mscclpp/ir.py | 301 -------- msccl/language/types.py | 24 +- tests/test_language.py | 61 -- 14 files changed, 142 insertions(+), 1668 deletions(-) delete mode 100644 examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py delete mode 100644 examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py delete mode 100644 examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py delete mode 100644 examples/mscclang/allreduce_a100_ring_mscclpp.py create mode 100644 examples/mscclang/allreduce_mi300_sm_mscclpp.py create mode 100644 examples/mscclang/allreduce_mi300_sm_mscclpp2.py delete mode 100644 msccl/language/channel.py delete mode 100644 msccl/language/mscclpp/__init__.py delete mode 100644 msccl/language/mscclpp/instruction_dag.py delete mode 100644 msccl/language/mscclpp/ir.py diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py deleted file mode 100644 index d9c6275..0000000 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -from msccl.language import * -from msccl.topologies import * -from msccl.language.collectives import AllReduce - - -def allreduce_allpairs(gpus, instances): - size = gpus - chunksperloop = gpus * gpus - topology = fully_connected(size) - collective = AllReduce(size, chunksperloop, True) - with MSCCLPPProgram( - "allreduce_pairs", - topology, - collective, - instances, - protocol="LL", - ): - - # Each rank sends the nth chunk to the nth rank into scratch space - for r1 in range(size): - for tb in range(size): - if tb == r1: - continue - remote_rank = tb - index = remote_rank * size - c = chunk(r1, Buffer.input, index, size) - c.put_packet(remote_rank, "scratch", index=r1*size, sendtb=tb) - - # Each rank performs a local reduction on the nth chunk - # Utilize 8 threadblocks for this reduction for better parallelism - for r in range(size): - for index in range(size): - c = chunk(r, Buffer.input, r * size + index) - for peer in range(size): - if peer != r: - c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index) - for peer in range(size): - if peer != r: - c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index) - - # Each rank get final result from scratch space - for r in range(size): - for peer in range(size): - if peer != r: - c = chunk(r, "scratch", size * size + peer * size, size) - c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) - - Json() - Check() - - -parser = argparse.ArgumentParser() -parser.add_argument("num_gpus", type=int, help="number of gpus") -parser.add_argument("instances", type=int, help="number of instances") - -args = parser.parse_args() - -allreduce_allpairs(args.num_gpus, args.instances) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py deleted file mode 100644 index 08717c4..0000000 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -from msccl.language import * -from msccl.topologies import * -from msccl.language.collectives import AllReduce - - -def allreduce_allpairs(gpus, instances, protocol): - size = gpus - chunksperloop = gpus * gpus - topology = fully_connected(size) - collective = AllReduce(size, chunksperloop, True) - with MSCCLPPProgram("allreduce_pairs", topology, collective, instances, protocol=protocol): - for rank in range(size): - for tb in range(size): - index = rank * size - c = chunk(rank, Buffer.input, index + tb) - # step1 make sure the data is ready - for nghr in range(size): - peer_index = nghr * size - if rank != nghr: - # signal peer the buffer is ready - c_peer = chunk(rank, Buffer.input, peer_index + tb) - c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) - for nghr in range(size): - if rank != nghr: - c.wait(nghr, Buffer.input, index + tb, recvtb=tb) - # step2 reduce the chunks and send to peers - for nghr in range(size): - if rank != nghr: - c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) - for nghr in range(size): - if rank != nghr: - c.put(nghr, Buffer.input, index + tb, sendtb=tb) - # step3 signal the peers buffer is ready - for nghr in range(size): - if rank != nghr: - c.signal(nghr, Buffer.input, index + tb, sendtb=tb) - for nghr in range(size): - if rank != nghr: - peer_index = nghr * size - c_peer = chunk(rank, Buffer.input, peer_index + tb) - c_peer.wait(nghr, Buffer.input, peer_index + tb, recvtb=tb) - - Json() - Check() - - -parser = argparse.ArgumentParser() -parser.add_argument("num_gpus", type=int, help="number of gpus") -parser.add_argument("instances", type=int, help="number of instances") -parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") - -args = parser.parse_args() - -allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py deleted file mode 100644 index 5407676..0000000 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -from msccl.language import * -from msccl.topologies import * -from msccl.language.collectives import AllReduce - - -def allreduce_allpairs(gpus, instances, protocol): - size = gpus - chunksperloop = gpus * gpus - topology = fully_connected(size) - collective = AllReduce(size, chunksperloop, True) - with MSCCLPPProgram( - "allreduce_pairs", - topology, - collective, - instances, - protocol=protocol, - ): - - # Each rank sends the nth chunk to the nth rank into scratch space - for rank in range(size): - for tb in range(size): - index = rank * size - c = chunk(rank, Buffer.input, index + tb) - # make sure the data is ready - for nghr in range(size): - peer_index = nghr * size - if rank != nghr: - c_peer = chunk(rank, Buffer.input, peer_index + tb) - c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) - for nghr in range(size): - if rank != nghr: - c.wait(nghr, Buffer.input, index + tb, recvtb=tb) - # reduce the chunks - for i in range(size): - nghr = (rank + i) % size - if rank != nghr: - c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) - for nghr in range(size): - if rank != nghr: - c.signal(nghr, Buffer.input, index + tb, sendtb=tb) - - # wait for all the chunks is ready, then get the chunks - for rank in range(size): - for tb in range(size): - for nghr in range(size): - if rank != nghr: - index = nghr * size - c = chunk(rank, Buffer.input, index + tb) - c.wait(nghr, Buffer.input, index + tb, recvtb=tb) - for i in range(size): - nghr = (rank + i) % size - index = nghr * size - if rank != nghr: - c = chunk(rank, Buffer.input, index + tb) - c.get(nghr, Buffer.input, index + tb, recvtb=tb) - - Json() - Check() - - -parser = argparse.ArgumentParser() -parser.add_argument("num_gpus", type=int, help="number of gpus") -parser.add_argument("instances", type=int, help="number of instances") -parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") - -args = parser.parse_args() - -allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/examples/mscclang/allreduce_a100_ring_mscclpp.py b/examples/mscclang/allreduce_a100_ring_mscclpp.py deleted file mode 100644 index 712f830..0000000 --- a/examples/mscclang/allreduce_a100_ring_mscclpp.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -from msccl.language import * -from msccl.topologies import * -from msccl.language.collectives import AllReduce - - -# Ring all reduce for A100s -def allreduce_ring(size, instances): - topology = fully_connected(size) - collective = AllReduce(size, size, True) - with MSCCLPPProgram( - f"allreduce_ring", - topology, - collective, - instances, - protocol="Simple", - ): - # Reduce ring - for step in range(0, size - 1): - for index in range(0, size): - rank = (index + step) % size - next_rank = (index + step + 1) % size - c = chunk(rank, Buffer.input, index) - c.signal(next_rank, Buffer.input, index, 0) - prev_rank = (index + step - 1) % size - c = chunk(rank, Buffer.input, (index + size - 1) % size) - c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0) - c.reduce(chunk(prev_rank, Buffer.input, (index + size - 1) % size), recvtb=0) - - # Propagate ring - for step in range(-1, size - 2): - for index in range(0, size): - rank = (index + step) % size - c = chunk(rank, Buffer.input, index) - next_rank = (index + step + 1) % size - c.put(next_rank, Buffer.input, index, sendtb=0) - c.signal(next_rank, Buffer.input, index, 0) - prev_rank = (index + step - 1) % size - c = chunk(rank, Buffer.input, (index + size - 1) % size) - c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0) - - Json() - Check() - - -parser = argparse.ArgumentParser() -parser.add_argument("num_gpus", type=int, help="number of gpus") -parser.add_argument("instances", type=int, help="number of instances") -args = parser.parse_args() - -allreduce_ring(args.num_gpus, args.instances) diff --git a/examples/mscclang/allreduce_mi300_sm_mscclpp.py b/examples/mscclang/allreduce_mi300_sm_mscclpp.py new file mode 100644 index 0000000..4e97661 --- /dev/null +++ b/examples/mscclang/allreduce_mi300_sm_mscclpp.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + + +def allreduce(gpus, instances, protocol): + size = gpus + chunksperloop = gpus * (gpus - 1) + topology = fully_connected(size) + collective = AllReduce(size, chunksperloop, True) + with MSCCLPPProgram( + "allreduce_mi300", + topology, + collective, + instances, + protocol=protocol, + ): + for rank in range(size): + for tb in range(size - 1): + for i in range(size): + nghr = (rank + tb + 1 + i) % size + if rank == nghr: + continue + c = chunk(rank, Buffer.input, nghr * (size - 1) + tb) + c.put(nghr, "scratch", rank * (size - 1) + tb, sendtb=tb) + for rank in range(size): + for tb in range(size - 1): + for i in range(size): + nghr = (rank + tb + 1 + i) % size + if rank == nghr: + continue + c = chunk(rank, Buffer.input, nghr * (size - 1) + tb) + c.signal(nghr, "scratch", rank * (size - 1) + tb, sendtb=tb) + for i in range(size): + nghr = (rank + tb + 1 + i) % size + if rank == nghr: + continue + c = chunk(rank, "scratch", nghr * (size - 1) + tb) + c.wait(nghr, Buffer.input, rank * (size - 1) + tb, recvtb=tb) + + for rank in range(size): + for tb in range(size - 1): + c = chunk(rank, Buffer.input, rank * (size - 1) + tb) + for nghr in range(size): + if rank != nghr: + index = nghr * (size - 1) + c.reduce(chunk(rank, "scratch", index + tb), recvtb=tb) + for i in range(size): + nghr = (rank + i) % size + index = rank * (size-1) + if rank != nghr: + c = chunk(rank, Buffer.input, index + tb) + c.put(nghr, Buffer.input, index + tb, sendtb=tb) + + Json() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") +parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") + +args = parser.parse_args() + +allreduce(args.num_gpus, args.instances, args.protocol) diff --git a/examples/mscclang/allreduce_mi300_sm_mscclpp2.py b/examples/mscclang/allreduce_mi300_sm_mscclpp2.py new file mode 100644 index 0000000..80a9e6d --- /dev/null +++ b/examples/mscclang/allreduce_mi300_sm_mscclpp2.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from msccl.language import * +from msccl.topologies import * +from msccl.language.collectives import AllReduce + + +def allreduce(gpus, instances, protocol): + size = gpus + chunksperloop = gpus * gpus + topology = fully_connected(size) + collective = AllReduce(size, chunksperloop, True) + with MSCCLPPProgram( + "allreduce_mi300", + topology, + collective, + instances, + protocol=protocol, + ): + for rank in range(size): + for tb in range(size): + for i in range(size): + nghr = (rank + tb + 1 + i) % size + if rank == nghr: + continue + c = chunk(rank, Buffer.input, nghr * size + tb) + c.put(nghr, "scratch", rank * size + tb, sendtb=tb) + for rank in range(size): + for tb in range(size): + for i in range(size): + nghr = (rank + tb + 1 + i) % size + if rank == nghr: + continue + c = chunk(rank, Buffer.input, nghr * size + tb) + c.signal(nghr, "scratch", rank * size + tb, sendtb=tb) + for i in range(size): + nghr = (rank + tb + 1 + i) % size + if rank == nghr: + continue + c = chunk(rank, "scratch", nghr * size + tb) + c.wait(nghr, Buffer.input, rank * size + tb, recvtb=tb) + + for rank in range(size): + for tb in range(size): + c = chunk(rank, Buffer.input, rank * size + tb) + for nghr in range(size): + if rank != nghr: + index = nghr * size + c.reduce(chunk(rank, "scratch", index + tb), recvtb=tb) + for i in range(size): + nghr = (rank + i) % size + index = rank * size + if rank != nghr: + c = chunk(rank, Buffer.input, index + tb) + c.put(nghr, Buffer.input, index + tb, sendtb=tb) + + Json() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") +parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") + +args = parser.parse_args() + +allreduce(args.num_gpus, args.instances, args.protocol) diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 51cd330..d83849a 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -7,9 +7,6 @@ from msccl.language.chunk import * from msccl.language.buffer import * from msccl.language.instruction_dag import * -import msccl.language.mscclpp as mscclpp -from msccl.language.mscclpp import * -from typing import Union from msccl.language.types import ReplicationPolicy, ThreadblockPolicy @@ -21,10 +18,8 @@ def _curr(): global _current_program - if _current_program == None and mscclpp._current_program == None: - raise RuntimeError("No Program in context") if _current_program == None: - return mscclpp._current_program + raise RuntimeError("No Program in context") return _current_program @@ -284,7 +279,7 @@ def Print(): _curr().print_chunk_dag() -def chunk(rank, buffer, index, size=1) -> Union[mscclpp.Ref, Ref]: +def chunk(rank, buffer, index, size=1) -> Ref: if _curr().buffers[rank][buffer][index] is None: return None return _curr().get_ref(rank, buffer, index, size) diff --git a/msccl/language/channel.py b/msccl/language/channel.py deleted file mode 100644 index fb97b7e..0000000 --- a/msccl/language/channel.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -from dataclasses import dataclass -from enum import Enum - -from msccl.language.buffer import Buffer - - -class ChannelType(Enum): - proxy = "proxy" - sm = "sm" - none = "none" - - def __str__(self): - return self.value - - -@dataclass(frozen=True) -class Channel: - srcBuffer: Buffer - dstBuffer: Buffer - type: ChannelType - connected_to: int diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index dd243ec..74db17a 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -3,12 +3,11 @@ from msccl.language import * class Collective(): - def __init__(self, num_ranks, chunk_factor, inplace, num_chunk_groups = 1): + def __init__(self, num_ranks, chunk_factor, inplace): self.num_ranks = num_ranks self.chunk_factor = chunk_factor self.inplace = inplace self.name = "custom" - self.num_chunk_groups = num_chunk_groups def init_buffers(self): pass diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py deleted file mode 100644 index 9bc27af..0000000 --- a/msccl/language/mscclpp/__init__.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from msccl.collectives import Collective -from msccl.language.buffer import * -from msccl.language.mscclpp.ir import * -from msccl.language.mscclpp.instruction_dag import MscclppInstructionDAG -from msccl.language.tb_assignment import * -from msccl.topologies.topology import Topology - -_current_program = None - - -def _curr(): - global _current_program - if _current_program == None: - raise RuntimeError("No Program in context") - return _current_program - - -# For msccl++ program, we have one assumption that for channel can be identified by (send_buffer, recv_buffer, type, send_tb/recv_tb) -# which means the send_tb and recv_tb should be the same for a pair of signal and wait, also same for put/get operation. -# If one sender what to send data to peer want to use different tb in receiver side. We need to send to same tb in receiver side first, -# then performance a across tb sync. This is a limitation of current implementation. -class MSCCLPPProgram: - def __init__( - self, - name: str, - topo: Topology, - collective: Collective, - instances: int, - protocol: str = "Simple", - instr_fusion: bool = True, - replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated, - ): - self.name = name - self.topo = topo - self.collective = collective - self.num_ranks = topo.num_nodes() - self.instances = instances - self.protocol = protocol - self.instr_fusion = instr_fusion - self.replication_policy = replication_policy - assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL" - self.run_opt = True # Runs optimization passes - # Initialize the input buffers - self.buffers = collective.init_buffers() - self.instr_dag = MscclppInstructionDAG(self.num_ranks, self.buffers) - for r in range(self.num_ranks): - for index, chunk in enumerate(self.buffers[r][Buffer.input]): - buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) - ref = self.get_ref(r, buffer, index, 1) - # self.chunk_dag.init_chunk(chunk, ref) - self.instr_dag.add_start(r, buffer, index, ref) - - def __enter__(self): - global _current_program - if _current_program != None: - raise RuntimeError("There is already a MSCCLPP Program in context") - _current_program = self - - def __exit__(self, exc_type, exc_value, exc_traceback): - global _current_program - if _current_program != self: - raise RuntimeError("This program is not currently in context") - _current_program = None - - # Tracks a send operation on the buffers - def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): - src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) - dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) - sb = self.buffers[src][src_buffer] - db = self.buffers[dst][dst_buffer] - for i in range(size): - db[dst_index + i] = sb[src_index + i] - - # Tracks a reduce operation on the buffers - def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): - src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) - dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) - sb = self.buffers[src][src_buffer] - db = self.buffers[dst][dst_buffer] - for i in range(size): - reduce_chunk = db[dst_index + i] - sent_chunk = sb[src_index + i] - db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) - - def get_ref(self, rank, buffer, index, size): - buffer, index = self.collective.get_buffer_index(rank, buffer, index) - return Ref(rank, buffer, index, size, self) - - def get_chunks(self, rank, buffer, index, size=1): - chunks = [None] * size - for i in range(0, size): - if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]): - chunks[i] = self.buffers[rank][buffer][index + i] - else: - chunks[i] = None - return chunks - - def check_buffer_exists(self, rank, name): - if name not in self.buffers[rank]: - self.buffers[rank][name] = BufferSlice(Buffer.scratch, name) - - # Checks that all chunks that should be on each rank - # are present in the output buffer. - def check(self): - return self.collective.check(self) - - # Lower program to MSCCLPP - def lower(self): - convert_to_exectuion_plan(self.instr_dag) - self.instr_dag.complete_channels() - if self.instr_fusion: - self.instr_dag.optimize() - self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy) - return Program( - self.name, - self.collective.name, - self.collective.inplace, - self.protocol, - gpu_prgms, - self.collective.num_chunk_groups * self.instances - ) - - def generate_json(self): - return ir_to_json(self.lower()) - - -def Json(): - print(_curr().generate_json()) - - -@dataclass -class Ref(ChunkRef): - prog: MSCCLPPProgram - - def __repr__(self): - return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})" - - def _end(self): - return self.index + self.size - - def _get_chunk(self, index): - return self.prog.buffers[self.rank][self.buffer][index] - - def split(self, num): - assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts" - chunks = [None] * num - size = self.size // num - for i in range(num): - index = self.index + i * size - chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size) - return chunks - - def group(self, other): - assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}" - assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}" - if self.index < other.index: - first = self - second = other - else: - first = other - second = self - - end = max(first._end(), second._end()) - return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) - - def _get_buffer_index(self, remote_rank, buffer, index): - if index == -1 and buffer == None: - return self.buffer, self.index - elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: - return buffer, self.prog.buffers[remote_rank][buffer].instance_size() - return buffer, index - - def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False): - self.prog.check_buffer_exists(dst, buffer) - assert self.rank != dst, "Cannot put to the same rank" - buffer, index = self._get_buffer_index(dst, buffer, index) - - # Direct put - assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - if use_packet: - self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet) - self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none) - self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none) - else: - self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type) - return dst_chunkref - - def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): - return self._put(dst, buffer, index, sendtb, chan_type) - - def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): - return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True) - - def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): - self.prog.check_buffer_exists(src, buffer) - sender = src - receiver = self.rank - assert sender != receiver, "Cannot get from the same rank" - buffer, index = self._get_buffer_index(src, buffer, index) - - # Direct get - assert self.prog.topo.link(self.rank, src) or src == self.rank, f"No link from {self.rank} to {src}" - src_chunkref = self.prog.get_ref(src, buffer, index, self.size) - - self.prog.apply_send(src, buffer, index, self.rank, self.buffer, self.index, self.size) - self.prog.instr_dag.add_get(receiver, src_chunkref, self, recvtb, chan_type) - - # for signal and wait, currently we assuem the pair will use the same tb index. In future we need - # to infer the tb index from the instruction DAG Add a channel is define as (send_tb, src_buffer, recv_tb, dst_buffer, type). - # Then we can use DAG info to reduce the number of channels. - def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): - sender = self.rank - receiver = dst - assert sender != receiver, "Cannot signal to the same rank" - buffer, index = self._get_buffer_index(dst, buffer, index) - - # Direct signal - assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type) - - def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): - sender = src - receiver = self.rank - assert sender != receiver, "Cannot wait on the same rank" - buffer, index = self._get_buffer_index(src, buffer, index) - - # Direct wait - assert self.prog.topo.link(self.rank, src) or src == self.rank, f"No link from {self.rank} to {src}" - src_chunkref = self.prog.get_ref(src, buffer, index, self.size) - self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type) - - def _copy(self, dst, buffer=None, index=-1, sendtb=-1, use_packet=False): - self.prog.check_buffer_exists(dst, buffer) - buffer, index = self._get_buffer_index(dst, buffer, index) - - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) - if dst_chunkref == self: - return - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - - assert self.rank == dst, "Chunk copy only supports intra-rank communication" - self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, use_packet) - - return dst_chunkref - - # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) - def copy(self, dst, buffer=None, index=-1, sendtb=-1): - return self._copy(dst, buffer, index, sendtb) - - def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): - return self._copy(dst, buffer, index, sendtb, use_packet=True) - - def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False): - dst = self.rank - src = other_chunkref.rank - assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}" - self.prog.apply_reduce( - src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size - ) - if use_packet: - assert src == dst, "Packet reduce only supports intra-rank communication" - - if src != dst: - self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type) - else: - self.prog.instr_dag.add_reduce(src, other_chunkref, self, recvtb, use_packet) - - return self - - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm): - return self._reduce(other_chunkref, recvtb, channel_type) - - # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref - def reduce_packet(self, other_chunkref, recvtb=-1): - return self._reduce(other_chunkref, recvtb, use_packet=True) - - def get_origin_index(self, index=0): - return self._get_chunk(index + self.index).origin_index - - def get_origin_rank(self, index=0): - return self._get_chunk(index + self.index).origin_rank - - def get_dst_index(self, index=0): - return self._get_chunk(index + self.index).dst_index - - def get_dst_rank(self, index=0): - return self._get_chunk(index + self.index).dst_rank - - def print_chunk_info(self, index=0): - print(self._get_chunk(index + self.index)) diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py deleted file mode 100644 index e0bf47a..0000000 --- a/msccl/language/mscclpp/instruction_dag.py +++ /dev/null @@ -1,704 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -from msccl.language.buffer import Buffer -from msccl.language.channel import Channel, ChannelType -from msccl.language.instruction_dag import ( - buf_dst_src_match, - merge_op, - remove_op, - same_buf_dst, - same_buf_src, - same_chan_type, - same_count, - same_src_dst_buffer_type, -) -from msccl.language.instruction_dag import InstructionDAG -from msccl.language.types import ChunkRef, MscclppInstruction as Instruction, Op, ReplicationPolicy, Threadblock - - -class MscclppInstructionDAG(InstructionDAG): - def __init__(self, num_ranks, buffers): - super().__init__(num_ranks, buffers) - - # InstructionDAG - adds a copy node - def add_copy(self, rank, send_ref, recv_ref, tb, use_packet=False): - tb_step = self._get_tb_step(rank, tb) - if use_packet: - op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) - else: - op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) - dstbuffer = recv_ref.buffer - dstindex = recv_ref.index - srcbuffer = send_ref.buffer - srcindex = send_ref.index - size = recv_ref.size - # Sending part of copy [Read] - self._read(rank, srcbuffer, srcindex, size, op) - # Receiving part of copy [Write] - self._write(rank, dstbuffer, dstindex, size, op) - return op - - # InstructionDAG - adds a redduce node - def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False): - tb_step = self._get_tb_step(rank, tb) - if use_packet: - op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) - else: - op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) - dstbuffer = recv_ref.buffer - dstindex = recv_ref.index - srcbuffer = send_ref.buffer - srcindex = send_ref.index - size = recv_ref.size - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - # Sending part of reduce - self._read(rank, srcbuffer, srcindex, size, op) - # Reduce part of copy - self._write(rank, dstbuffer, dstindex, size, op, read=True) - return op - - # InstructionDAG - adds a put node - def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False): - tb_step = self._get_tb_step(rank, tb) - if use_packet: - op = Op( - Instruction.put_packet, - rank, - send_ref, - recv_ref, - next=set(), - prev=set(), - tb=tb, - channel_type=ch_type, - step=tb_step, - ) - else: - op = Op( - Instruction.put, - rank, - send_ref, - recv_ref, - next=set(), - prev=set(), - tb=tb, - channel_type=ch_type, - step=tb_step, - ) - buffer = send_ref.buffer - index = send_ref.index - size = send_ref.size - self._read(rank, buffer, index, size, op) - op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - return op - - def add_get(self, rank, send_ref, recv_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op( - Instruction.get, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step - ) - buffer = recv_ref.buffer - index = recv_ref.index - size = recv_ref.size - self._write(rank, buffer, index, size, op) - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) - return op - - # InstructionDAG - adds a signal node. - def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op( - Instruction.signal, - rank, - send_ref, - recv_ref, - next=set(), - prev=set(), - tb=tb, - channel_type=ch_type, - step=tb_step, - ) - buffer = send_ref.buffer - index = send_ref.index - size = send_ref.size - # treat signal as a write since it can not be executed parallelly with read operations - self._write(rank, buffer, index, size, op) - op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - return op - - def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op( - Instruction.wait, rank, src_ref, dst_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step - ) - buffer = dst_ref.buffer - index = dst_ref.index - size = dst_ref.size - self._write(rank, buffer, index, size, op) - op.srcs.append((ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size), tb_step)) - op.dsts.append((ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size), tb_step)) - return op - - def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op( - Instruction.read_reduce_copy, - rank, - send_ref, - recv_ref, - next=set(), - prev=set(), - tb=tb, - channel_type=ch_type, - step=tb_step, - ) - buffer = recv_ref.buffer - index = recv_ref.index - size = recv_ref.size - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) - self._write(rank, buffer, index, size, op, read=True) - return op - - def complete_channels(self): - send_op = [Instruction.put, Instruction.signal, Instruction.put_packet] - recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - chans = set() - for op in tb.ops: - src_buffer = ( - Buffer.scratch - if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output - else op.src.buffer - ) - dst_buffer = ( - Buffer.scratch - if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output - else op.dst.buffer - ) - if op.inst in send_op: - chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) - chans.add(chan) - elif op.inst in recv_op: - chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) - chans.add(chan) - tb.channels = list(chans) - - def _optimize_redundant_signal_wait(self): - # For packet ops, we can remove signal/wait - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.put_packet: - fused = False - for next_op in op.next: - if next_op.inst == Instruction.signal: - remove_op(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet: - fused = False - for prev_op in op.prev: - if prev_op.inst == Instruction.wait: - remove_op(prev_op) - fused = True - break - if fused: - continue - queue = queue[1:] - - # rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di) - # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di]) - # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) - # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di) - # reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di) - def _optimize_rrc_r_signal_wait(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.read_reduce_copy: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.read_reduce_copy - and same_count(op, next_op) - and same_buf_dst(op, next_op) - and same_chan_type(op, next_op) - ): - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.reduce: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.reduce - and same_buf_dst(op, next_op) - and same_chan_type(op, next_op) - ): - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.reduce_packet: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.reduce_packet - and same_buf_dst(op, next_op) - and same_chan_type(op, next_op) - ): - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.signal: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.signal - and same_buf_src(op, next_op) - and same_chan_type(op, next_op) - ): - op.dsts.append( - ( - ChunkRef( - next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size - ), - next_op.step, - ) - ) - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.wait: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.wait - and same_buf_dst(op, next_op) - and same_chan_type(op, next_op) - ): - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - op.dsts.append( - ( - ChunkRef( - next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size - ), - next_op.step, - ) - ) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - queue = queue[1:] - - # rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_) - # reduce(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rs(_,_,_,_,_,_) - def _optimize_rrcs_rs(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.read_reduce_copy or op.inst == Instruction.read_reduce_copy_send: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.put - and same_count(op, next_op) - and buf_dst_src_match(op, next_op) - and same_chan_type(op, next_op) - ): - if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: - continue - if op.inst == Instruction.read_reduce_copy: - op.inst = Instruction.read_reduce_copy_send - op.dsts.append( - ( - ChunkRef( - next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size - ), - next_op.step, - ) - ) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - if op.inst == Instruction.reduce or op.inst == Instruction.reduce_send: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.put - and same_count(op, next_op) - and buf_dst_src_match(op, next_op) - and next_op.channel_type == ChannelType.sm - ): - if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: - continue - if op.inst == Instruction.reduce: - op.inst = Instruction.reduce_send - op.channel_type = ChannelType.sm - op.dsts.append( - ( - ChunkRef( - next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size - ), - next_op.step, - ) - ) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - if op.inst == Instruction.reduce_packet or op.inst == Instruction.reduce_send_packet: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.put_packet - and same_count(op, next_op) - and buf_dst_src_match(op, next_op) - and next_op.channel_type == ChannelType.sm - ): - if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: - continue - if op.inst == Instruction.reduce_packet: - op.inst = Instruction.reduce_send_packet - op.channel_type = ChannelType.sm - op.dsts.append( - ( - ChunkRef( - next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size - ), - next_op.step, - ) - ) - remove_op(next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - queue = queue[1:] - - # get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di]) - # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di]) - def _optimize_get_put(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.get: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if ( - seq_op.inst == Instruction.get - and same_src_dst_buffer_type(op, seq_op) - and same_chan_type(op, seq_op) - and same_count(op, seq_op) - ): - op.dsts.append( - ( - ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), - seq_op.step, - ) - ) - op.srcs.append( - ( - ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), - seq_op.step, - ) - ) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - if fused: - continue - elif op.inst == Instruction.put: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if ( - seq_op.inst == Instruction.put - and same_src_dst_buffer_type(op, seq_op) - and same_chan_type(op, seq_op) - and same_count(op, seq_op) - ): - op.dsts.append( - ( - ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), - seq_op.step, - ) - ) - op.srcs.append( - ( - ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), - seq_op.step, - ) - ) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - if fused: - continue - elif op.inst == Instruction.put_packet: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if ( - seq_op.inst == Instruction.put_packet - and same_src_dst_buffer_type(op, seq_op) - and same_chan_type(op, seq_op) - and same_count(op, seq_op) - ): - op.dsts.append( - ( - ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), - seq_op.step, - ) - ) - op.srcs.append( - ( - ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), - seq_op.step, - ) - ) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - queue = queue[1:] - - # For signal/wait ops, if they are independent of other operations and no other operations in between, - # then merge them into a single signal/wait op - # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) - def _parallel_signal_wait(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - if tbid == -1: - continue - queue = list(tb.ops) - while len(queue) > 0: - op = queue[0] - if op.inst == Instruction.signal: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if ( - seq_op.inst == Instruction.signal - and same_src_dst_buffer_type(op, seq_op) - and same_chan_type(op, seq_op) - ): - op.dsts.append( - ( - ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), - seq_op.step, - ) - ) - op.srcs.append( - ( - ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), - seq_op.step, - ) - ) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - if fused: - continue - elif op.inst == Instruction.wait: - fused = False - if len(queue) > 1: - seq_op = queue[1] - if ( - seq_op.inst == Instruction.wait - and same_src_dst_buffer_type(op, seq_op) - and same_chan_type(op, seq_op) - ): - op.dsts.append( - ( - ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), - seq_op.step, - ) - ) - op.srcs.append( - ( - ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), - seq_op.step, - ) - ) - merge_op(op, seq_op) - tb.ops.remove(seq_op) - queue.remove(seq_op) - fused = True - if fused: - continue - queue = queue[1:] - - def _get_tb_step(self, rank: int, tb: int): - if tb in self.tb_steps[rank]: - self.tb_steps[rank][tb] += 1 - return self.tb_steps[rank][tb] - else: - self.tb_steps[rank][tb] = 0 - return 0 - - def optimize(self): - self._optimize_redundant_signal_wait() - self._optimize_rrc_r_signal_wait() - self._optimize_rrcs_rs() - self._optimize_get_put() - - self._parallel_signal_wait() - - def replicate(self, instances: int, replication_policy: ReplicationPolicy): - # update op step - for rank, rank_tbs in enumerate(self.tbs): - for _, tb in rank_tbs.items(): - for id, op in enumerate(tb.ops): - op.step = id - - if instances == 1: - self.instanced_tbs = self.tbs - return - - self.instanced_tbs = [] - for _ in range(self.num_ranks): - self.instanced_tbs.append({}) - - def is_scratch(buffer): - return buffer != Buffer.input and buffer != Buffer.output - - def get_new_index(rank, buffer, index, size, i): - # Scratch buffers always use batched - if is_scratch(buffer): - buf_instance_len = self.buffers[rank][buffer].instance_size() - return buf_instance_len * i + index - return len(self.buffers[rank][buffer]) * i + index - - def get_instance_ref(ref): - iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i) - iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) - return iref - - if replication_policy == ReplicationPolicy.duplicated: - for i in range(instances): - # Generate all the threadblocks and ops - for rank, rank_tbs in enumerate(self.tbs): - # rank_channels = self.num_channels[rank] - for tbid, tb in rank_tbs.items(): - itbid = tbid * instances + i - itb = Threadblock(id=itbid) - itb.ops = [None] * len(tb.ops) - for s, op in enumerate(tb.ops): - isrc = get_instance_ref(op.src) - idst = get_instance_ref(op.dst) - idepends = [] - # Note: We don't need the fill out the rest of the metadata since replication is the last optimization - iop = Op( - op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type - ) - itb.ops[s] = iop - for src, step in op.srcs: - isrc = get_instance_ref(src) - iop.srcs.append((isrc, step)) - for dst, step in op.dsts: - idst = get_instance_ref(dst) - iop.dsts.append((idst, step)) - for chan in tb.channels: - itb.channels.append(chan) - self.instanced_tbs[op.rank][itbid] = itb - - # Redo dependency analysis - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): - for i in range(instances): - itbid = tbid * instances + i - itb = self.instanced_tbs[rank][itbid] - for op, iop in zip(tb.ops, itb.ops): - iop.depends = [None] * len(op.depends) - for s, dep in enumerate(op.depends): - dep_tbid = dep.tb - dep_itbid = dep_tbid * instances + i - dep_step = dep.step - iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py deleted file mode 100644 index 2c8f37e..0000000 --- a/msccl/language/mscclpp/ir.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from collections import defaultdict -from dataclasses import dataclass -import json - -from msccl.language.types import Buffer, ChannelType, Op, Program, MscclppInstruction as Instruction - -_local_src_insts_mscclpp = { - Instruction.put, - Instruction.put_packet, - Instruction.signal, - Instruction.copy, - Instruction.copy_packet, - Instruction.reduce, - Instruction.reduce_packet, - Instruction.reduce_send, - Instruction.reduce_send_packet, -} -_local_dst_insts_mscclpp = { - Instruction.get, - Instruction.wait, - Instruction.read_reduce_copy, - Instruction.copy, - Instruction.copy_packet, - Instruction.reduce, - Instruction.read_reduce_copy_send, - Instruction.reduce_send, - Instruction.reduce_packet, - Instruction.reduce_send_packet, -} - - -def ir_to_json(program: Program): - # Figure out sizes of buffers based on usage - buffer_sizes = defaultdict(lambda: 0) - for gpu in program.gpus: - for tb in gpu.threadblocks: - for op in tb.ops: - if op.inst in _local_src_insts_mscclpp: - key = (gpu.rank, op.src.buffer) - buffer_sizes[key] = max(buffer_sizes[key], op.src.index + op.src.size) - for src in op.srcs: - key = (gpu.rank, src.buffer) - buffer_sizes[key] = max(buffer_sizes[key], src.index + src.size) - if op.inst in _local_dst_insts_mscclpp: - key = (gpu.rank, op.dst.buffer) - buffer_sizes[key] = max(buffer_sizes[key], op.dst.index + op.dst.size) - # ignore remote buffers - if ( - op.inst != Instruction.read_reduce_copy_send - and op.inst != Instruction.reduce_send - and op.inst != Instruction.reduce_send_packet - ): - for dst in op.dsts: - key = (gpu.rank, dst.buffer) - buffer_sizes[key] = max(buffer_sizes[key], dst.index + dst.size) - for gpu in program.gpus: - gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks) - gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) - gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) - - # Since LL protocol will double the scratch size. We need to make sure all GPUs have the same scratch size. - # Otherwise the offset calculation will be wrong. - if program.protocol == "LL": - max_scratch = max(gpu.scratch_chunks for gpu in program.gpus) - for gpu in program.gpus: - gpu.scratch_chunks = max_scratch - - # get channel info for each GPU and threadblock - for gpu in program.gpus: - gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id) - chan_dict = {} - # the channel key is the tuple (srcBuffer, dstBuffer, type) - for tb in gpu.threadblocks: - for ch in tb.channels: - key = (ch.srcBuffer, ch.dstBuffer, ch.type) - if key not in chan_dict: - chan_dict[key] = [(tb.id, ch.connected_to)] - else: - chan_dict[key].append((tb.id, ch.connected_to)) - for key, value in chan_dict.items(): - chan_dict[key] = sorted(value) - gpu.channels = chan_dict - - # Remove the dependencies of wait after signal. They are actually depends on remote chunk - for gpu in program.gpus: - for tb in gpu.threadblocks: - for op in tb.ops: - if op.inst == Instruction.wait: - op.depends = list(filter(lambda dep: dep.inst != Instruction.signal, op.depends)) - - # Filter out redundant dependencies - # e.g. if op1 and op2 depend on op, and op1 happends before op2 - # then op2 does not need to explicitly depend on op - for gpu in program.gpus: - for tb in gpu.threadblocks: - running_depends = [] - for op in tb.ops: - op.depends = list(filter(lambda dep: dep not in running_depends, op.depends)) - running_depends = running_depends + op.depends - - # Do some additional postprocessing of operations: - # - Expand operations with dependencies with no-ops - if program.protocol != "LL": # TODO(binyli): fix this. Should based on OP type not algorithm - for gpu in program.gpus: - for tb in gpu.threadblocks: - new_ops = [] - for op in tb.ops: - # Expand extra dependencies into nop operations - for i, dep in enumerate(op.depends): - new_ops.append(Op(Instruction.nop, -1, None, None, [dep])) - new_ops.append(op) - tb.ops = new_ops - - # update step and tid for ops - for gpu in program.gpus: - for tb in gpu.threadblocks: - for i, op in enumerate(tb.ops): - op.step = i - op.tb = tb.id - - # Need to calculate channel info for each GPU - nchannels = 0 - for gpu in program.gpus: - max_tb_channels = 0 - if len(gpu.threadblocks) > 0: - max_tb_channels = max(tb.channel + 1 for tb in gpu.threadblocks) - nchannels = max(nchannels, max_tb_channels) - return dump_to_json(program) - - -def dump_to_json(program: Program): - gpus = [] - - def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_type): - channel_ids = [] - for c in chunk_list: - key = (src_buffer, dst_buffer, chan_type) - channel_ids.extend( - [ - {"id": id, "off": c.index} - for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) - if ele == c.rank - ] - ) - return channel_ids - - def remove_empty_fields(d): - return {k: v for k, v in d.items() if v not in [None, "", [], {}]} - - for id, gpu in enumerate(program.gpus): - gpu_instance = { - "id": id, - "inputChunks": gpu.input_chunks, - "outputChunks": gpu.output_chunks, - "scratchChunks": gpu.scratch_chunks, - "chunkGroups": program.num_chunk_groups, - "threadblocks": [], - "channels": [], - } - for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): - obj = { - "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer, - "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer, - "type": type.value, - "connectedTo": [eles[1] for eles in channels], - } - gpu_instance["channels"].append(obj) - gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) - gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"])) - for tb in gpu.threadblocks: - if tb.id < 0: - continue - ops = [] - tb_channels = [] - tb_channel_dict = {} - for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): - obj = { - "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer, - "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer, - "type": type.value, - "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id], - "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id], - } - tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj - tb_channels.append(obj) - tb_channels = filter(lambda x: x["type"] != "none", tb_channels) - tb_channels = sorted(tb_channels, key=lambda x: (x["srcbuff"], x["dstbuff"])) - for op in tb.ops: - o_buff = None - i_buff = None - dst_channel_ids = [] - src_channel_ids = [] - srcs = [] - dsts = [] - src = None - dst = None - if op.tb == -1: - continue - if op.inst == Instruction.signal: - # get dst channel ids - dst_channel_ids = get_channel_ids( - op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type - ) - o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - elif op.inst == Instruction.wait: - # get src channel ids - src_channel_ids = get_channel_ids( - op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type - ) - i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - elif op.inst == Instruction.read_reduce_copy: - src_channel_ids = get_channel_ids( - op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type - ) - i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - dst = op.dst - src = op.dst # TODO(binyli): fix this - elif op.inst == Instruction.read_reduce_copy_send: - src_channel_ids = get_channel_ids( - op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type - ) - dst_channel_ids = get_channel_ids( - op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type - ) - i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} - dst = op.dst - src = op.dst # TODO(binyli): fix this - elif op.inst == Instruction.reduce_send or op.inst == Instruction.reduce_send_packet: - dst_channel_ids = get_channel_ids( - op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm - ) - o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} - srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) - dst = op.dst - src = op.dst # TODO(binyli): fix this - elif op.inst == Instruction.reduce: - srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) - dst = op.dst - elif op.inst == Instruction.nop: - instr = { - "name": op.inst.value, - "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)), - } - elif op.inst == Instruction.put or op.inst == Instruction.put_packet: - dst_channel_ids = get_channel_ids( - op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type - ) - o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) - elif op.inst == Instruction.get: - src_channel_ids = get_channel_ids( - op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type - ) - i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts)) - elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet: - src = op.src - dst = op.dst - if op.inst != Instruction.nop: - instr = { - "name": op.inst.value, - "i_buff": i_buff, - "i_cids": src_channel_ids, - "o_buff": o_buff, - "o_cids": dst_channel_ids, - "src": src.rank if src else None, - "srcs": srcs if srcs else None, - "dsts": dsts if dsts else None, - "srcbuff": src.buffer.value if src and src.buffer else None, - "srcoff": src.index if src else None, - "dst": dst.rank if dst else None, - "dstbuff": dst.buffer.value if dst and dst.buffer else None, - "dstoff": dst.index if dst else None, - "ctype": op.channel_type.value, - "cnt": op.cnt(), - } - ops.append(remove_empty_fields(instr)) - threadblock = { - "id": tb.id, - "ops": ops, - "channels": list( - map( - lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cids": x["chanIds"]}, - tb_channels, - ) - ), - } - gpu_instance["threadblocks"].append(threadblock) - gpus.append(gpu_instance) - obj = { - "name": program.name, - "colletive": program.collective, - "protocol": program.protocol, - "inplace": program.inplace, - "gpus": gpus, - } - return json.dumps(obj, indent=2) diff --git a/msccl/language/types.py b/msccl/language/types.py index 97a4277..6236457 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Union from msccl.language.buffer import Buffer from msccl.language.channel import ChannelType @@ -104,27 +103,6 @@ def __str__(self): return self.value -class MscclppInstruction(Enum): - nop = "nop" - read_reduce_copy = "rrc" - read_reduce_copy_send = "rrcs" - reduce_send = "rs" - copy = "copy" - reduce = "reduce" - copy_packet = "cpkt" - reduce_send_packet = "rspkt" - reduce_packet = "rpkt" - put = "put" - put_packet = "ppkt" - get = "get" - wait = "wait" - signal = "signal" - flush = "flush" - - def __str__(self): - return self.value - - @dataclass class ChunkRef: rank: int @@ -138,7 +116,7 @@ def __hash__(self): @dataclass class Op: - inst: Union[Instruction, MscclppInstruction] + inst: Instruction rank: int src: ChunkRef dst: ChunkRef diff --git a/tests/test_language.py b/tests/test_language.py index bcd22ab..fcb5943 100755 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import msccl -from msccl.language.types import MscclppInstruction from msccl.topologies import line, fully_connected from msccl.language import * from msccl.language.routines import * @@ -202,32 +201,6 @@ def test_instruction_fusion(): assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == Instruction.recv assert lowered_prgm.gpus[2].threadblocks[0].ops[0].inst == Instruction.recv_reduce_copy_send - -def test_instruction_fusion_mscclpp(): - topology = fully_connected(3) - collective = AllReduce(3, 3, True) - prgm = MSCCLPPProgram("allreduce", topology, collective, 1) - with prgm: - c01 = chunk(1, Buffer.input, 0, 3).reduce(chunk(0, Buffer.input, 0, 3), recvtb=0) - c01.signal(2, Buffer.input, 0, sendtb=0) - c012 = chunk(2, Buffer.input, 0, 3) - c012.wait(1, Buffer.input, 0, recvtb=0) - c012.reduce(c01, recvtb=0).put(0, Buffer.input, 0, sendtb=0) - c012.signal(0, Buffer.input, 0, sendtb=0) - c0 = chunk(0, Buffer.input, 0, 3) - c0.wait(2, Buffer.input, 0, recvtb=0) - c0.put(1, Buffer.input, 0, sendtb=0) - assert Check() - lowered_prgm = prgm.lower() - assert lowered_prgm.gpus[0].threadblocks[0].ops[0].inst == MscclppInstruction.wait - assert lowered_prgm.gpus[0].threadblocks[0].ops[1].inst == MscclppInstruction.put - assert lowered_prgm.gpus[1].threadblocks[0].ops[0].inst == MscclppInstruction.read_reduce_copy - assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == MscclppInstruction.signal - assert lowered_prgm.gpus[2].threadblocks[0].ops[0].inst == MscclppInstruction.wait - assert lowered_prgm.gpus[2].threadblocks[0].ops[1].inst == MscclppInstruction.read_reduce_copy_send - assert lowered_prgm.gpus[2].threadblocks[0].ops[2].inst == MscclppInstruction.signal - - def test_replication(): topology = fully_connected(2) collective = AllToAll(2, 1, False) @@ -315,37 +288,3 @@ def test_routines_allreduce_nodes(): c = c.copy(r, Buffer.output, exchange_index) XML() assert Check() - -def test_routines_allreduce_packet_inplace_mscclpp(): - size = 8 - topology = fully_connected(size) - collective = AllReduce(size, size * size, True) - with MSCCLPPProgram("allreduce_packet", topology, collective, 2, protocol="LL"): - # Each rank sends the nth chunk to the nth rank into scratch space - for r1 in range(size): - for tb in range(size): - if tb == r1: - continue - remote_rank = tb - index = remote_rank * size - c = chunk(r1, Buffer.input, index, size) - c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb) - # Each rank performs a local reduction on the nth chunk - # Utilize 8 threadblocks for this reduction for better parallelism - for r in range(size): - for index in range(size): - c = chunk(r, Buffer.input, r * size + index) - for peer in range(size): - if peer != r: - c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index) - for peer in range(size): - if peer != r: - c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index) - # Each rank get final result from scratch space - for r in range(size): - for peer in range(size): - if peer != r: - c = chunk(r, "scratch", size * size + peer * size, size) - c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) - Json() - assert Check() From 8f46276f6482d5f59e54ed05757b9858d494f158 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 03:40:11 +0000 Subject: [PATCH 53/76] revert --- .../mscclang/allreduce_mi300_sm_mscclpp.py | 69 ------------------- .../mscclang/allreduce_mi300_sm_mscclpp2.py | 69 ------------------- 2 files changed, 138 deletions(-) delete mode 100644 examples/mscclang/allreduce_mi300_sm_mscclpp.py delete mode 100644 examples/mscclang/allreduce_mi300_sm_mscclpp2.py diff --git a/examples/mscclang/allreduce_mi300_sm_mscclpp.py b/examples/mscclang/allreduce_mi300_sm_mscclpp.py deleted file mode 100644 index 4e97661..0000000 --- a/examples/mscclang/allreduce_mi300_sm_mscclpp.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -from msccl.language import * -from msccl.topologies import * -from msccl.language.collectives import AllReduce - - -def allreduce(gpus, instances, protocol): - size = gpus - chunksperloop = gpus * (gpus - 1) - topology = fully_connected(size) - collective = AllReduce(size, chunksperloop, True) - with MSCCLPPProgram( - "allreduce_mi300", - topology, - collective, - instances, - protocol=protocol, - ): - for rank in range(size): - for tb in range(size - 1): - for i in range(size): - nghr = (rank + tb + 1 + i) % size - if rank == nghr: - continue - c = chunk(rank, Buffer.input, nghr * (size - 1) + tb) - c.put(nghr, "scratch", rank * (size - 1) + tb, sendtb=tb) - for rank in range(size): - for tb in range(size - 1): - for i in range(size): - nghr = (rank + tb + 1 + i) % size - if rank == nghr: - continue - c = chunk(rank, Buffer.input, nghr * (size - 1) + tb) - c.signal(nghr, "scratch", rank * (size - 1) + tb, sendtb=tb) - for i in range(size): - nghr = (rank + tb + 1 + i) % size - if rank == nghr: - continue - c = chunk(rank, "scratch", nghr * (size - 1) + tb) - c.wait(nghr, Buffer.input, rank * (size - 1) + tb, recvtb=tb) - - for rank in range(size): - for tb in range(size - 1): - c = chunk(rank, Buffer.input, rank * (size - 1) + tb) - for nghr in range(size): - if rank != nghr: - index = nghr * (size - 1) - c.reduce(chunk(rank, "scratch", index + tb), recvtb=tb) - for i in range(size): - nghr = (rank + i) % size - index = rank * (size-1) - if rank != nghr: - c = chunk(rank, Buffer.input, index + tb) - c.put(nghr, Buffer.input, index + tb, sendtb=tb) - - Json() - - -parser = argparse.ArgumentParser() -parser.add_argument("num_gpus", type=int, help="number of gpus") -parser.add_argument("instances", type=int, help="number of instances") -parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") - -args = parser.parse_args() - -allreduce(args.num_gpus, args.instances, args.protocol) diff --git a/examples/mscclang/allreduce_mi300_sm_mscclpp2.py b/examples/mscclang/allreduce_mi300_sm_mscclpp2.py deleted file mode 100644 index 80a9e6d..0000000 --- a/examples/mscclang/allreduce_mi300_sm_mscclpp2.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import argparse -from msccl.language import * -from msccl.topologies import * -from msccl.language.collectives import AllReduce - - -def allreduce(gpus, instances, protocol): - size = gpus - chunksperloop = gpus * gpus - topology = fully_connected(size) - collective = AllReduce(size, chunksperloop, True) - with MSCCLPPProgram( - "allreduce_mi300", - topology, - collective, - instances, - protocol=protocol, - ): - for rank in range(size): - for tb in range(size): - for i in range(size): - nghr = (rank + tb + 1 + i) % size - if rank == nghr: - continue - c = chunk(rank, Buffer.input, nghr * size + tb) - c.put(nghr, "scratch", rank * size + tb, sendtb=tb) - for rank in range(size): - for tb in range(size): - for i in range(size): - nghr = (rank + tb + 1 + i) % size - if rank == nghr: - continue - c = chunk(rank, Buffer.input, nghr * size + tb) - c.signal(nghr, "scratch", rank * size + tb, sendtb=tb) - for i in range(size): - nghr = (rank + tb + 1 + i) % size - if rank == nghr: - continue - c = chunk(rank, "scratch", nghr * size + tb) - c.wait(nghr, Buffer.input, rank * size + tb, recvtb=tb) - - for rank in range(size): - for tb in range(size): - c = chunk(rank, Buffer.input, rank * size + tb) - for nghr in range(size): - if rank != nghr: - index = nghr * size - c.reduce(chunk(rank, "scratch", index + tb), recvtb=tb) - for i in range(size): - nghr = (rank + i) % size - index = rank * size - if rank != nghr: - c = chunk(rank, Buffer.input, index + tb) - c.put(nghr, Buffer.input, index + tb, sendtb=tb) - - Json() - - -parser = argparse.ArgumentParser() -parser.add_argument("num_gpus", type=int, help="number of gpus") -parser.add_argument("instances", type=int, help="number of instances") -parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") - -args = parser.parse_args() - -allreduce(args.num_gpus, args.instances, args.protocol) From 775d9d60a7bea3b9db225fa4965f55faa6f51238 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 03:53:22 +0000 Subject: [PATCH 54/76] fix --- msccl/language/collectives.py | 2 +- msccl/language/types.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index 74db17a..8afc458 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -118,7 +118,7 @@ def get_buffer_index(self, rank, buffer, index): class AllReduce(Collective): def __init__(self, num_ranks, chunk_factor, inplace): - Collective.__init__(self, num_ranks, chunk_factor, inplace, num_ranks) + Collective.__init__(self, num_ranks, chunk_factor, inplace) self.name = "allreduce" def init_buffers(self): diff --git a/msccl/language/types.py b/msccl/language/types.py index 6236457..e914dfd 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -5,7 +5,6 @@ from enum import Enum from msccl.language.buffer import Buffer -from msccl.language.channel import ChannelType @dataclass @@ -32,7 +31,6 @@ class Gpu: output_chunks: int = 0 scratch_chunks: int = 0 scratch: dict = field(default_factory=dict) - channels: dict = field(default_factory=dict) def scratch_size(self): return max((idx for addr, idx in self.scratch.items()), default=-1) + 1 @@ -46,7 +44,6 @@ class Threadblock: recv: int = -1 ops: list = field(default_factory=list) rbid: int = -1 # threadblock id of the receiver - channels: list = field(default_factory=list) def __eq__(self, other): return self is other @@ -73,8 +70,6 @@ def __str__(self): class ReplicationPolicy(Enum): - # this means pack multi instrances to deal with the same chunk and share the channels - packed = "packed" # this means each instance deal with the different chunk # Chunk A, Chunk B -> Chunk A0, Chunk B0, Chunk A1, Chunk B1 duplicated = "duplicated" @@ -131,7 +126,6 @@ class Op: recv_match = None send_match = None channel: int = -1 - channel_type: ChannelType = ChannelType.none srcs: list = field(default_factory=list) dsts: list = field(default_factory=list) From 8a58c84e9aeab07f8732eda9a9de4d3aaba0ee18 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 08:11:10 +0000 Subject: [PATCH 55/76] WIP --- .github/workflows/tests.yaml | 38 +++++++++++++- tests/configs/example-config.json | 86 +++++++++++++++++++++++++++++++ tests/generate_exmpale_results.py | 46 +++++++++++++++++ 3 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 tests/configs/example-config.json create mode 100644 tests/generate_exmpale_results.py diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f3b6c8c..c8b217c 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -16,9 +16,9 @@ jobs: name: Test with Python ${{ matrix.python-version }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install msccl and dependencies @@ -28,3 +28,37 @@ jobs: - name: Run tests and check at least 85% coverage run: | pytest + + compare_outputs: + runs-on: ubuntu-latest + name: Compare outputs + + steps: + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Checkout main branch + uses: actions/checkout@v4 + with: + ref: main + - name: Install msccl and dependencies + run: | + pip install --upgrade pip + pip install -r requirements.txt + - name: generate outputs + run: | + python tests/generate_exmpale_results.py -i examples/mscclang/ -o $RUNNER_TEMP/tests/main-outputs/ -c tests/configs/example-config.json + - name: Checkout current branch + uses: actions/checkout@v4 + - name: Install msccl and dependencies + run: | + pip install --upgrade pip + pip install -r requirements.txt + - name: generate outputs + run: | + python tests/generate_exmpale_results.py -i examples/mscclang/ -o $RUNNER_TEMP/tests/pr-outputs/ -c tests/configs/example-config.json + - name: Compare outputs + run: | + diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ + diff --git a/tests/configs/example-config.json b/tests/configs/example-config.json new file mode 100644 index 0000000..caffdf8 --- /dev/null +++ b/tests/configs/example-config.json @@ -0,0 +1,86 @@ +[ + { + "filename": "allgather_a100_pcie.py", + "args": ["4", "2"] + }, + { + "filename": "allgather_allpairs.py", + "args": ["8", "2"] + }, + { + "filename": "allgather_recursive_doubling.py", + "args": ["8", "4"] + }, + { + "filename": "allgather_ring.py", + "args": ["8", "8", "4"] + }, + { + "filename": "allreduce_1step.py", + "args": ["8", "1"] + }, + { + "filename": "allreduce_a100_allpairs_v2.py", + "args": ["8", "2"] + }, + { + "filename": "allreduce_a100_allpairs.py", + "args": ["8", "4"] + }, + { + "filename": "allreduce_a100_multinode_allpairs.py", + "args": ["16", "2"] + }, + { + "filename": "allreduce_a100_ncv4_v2.py", + "args": ["4", "2"] + }, + { + "filename": "allreduce_a100_ncv4.py", + "args": ["4", "2"] + }, + { + "filename": "allreduce_a100_pcie_hierarchical.py", + "args": ["8", "2"] + }, + { + "filename": "allreduce_a100_ring.py", + "args": ["8", "8", "4"] + }, + { + "filename": "allreduce_recursive_doubling_halving.py", + "args": ["8", "4"] + }, + { + "filename": "alltoall_a100_three_step.py", + "args": ["2", "8", "2"] + }, + { + "filename": "alltoall_a100_two_step.py", + "args": ["8", "2"] + }, + { + "filename": "alltoall_allpairs.py", + "args": ["8", "2"] + }, + { + "filename": "alltonext_backward.py", + "args": ["8", "2"] + }, + { + "filename": "alltonext_forward.py", + "args": ["8", "2"] + }, + { + "filename": "hierarchical_allreduce.py", + "args": ["8", "2", "2"] + }, + { + "filename": "pipeline_a100_allpairs.py", + "args": ["8", "2"] + }, + { + "filename": "pipeline_a100_ring.py", + "args": ["8", "4", "2"] + } +] diff --git a/tests/generate_exmpale_results.py b/tests/generate_exmpale_results.py new file mode 100644 index 0000000..dcf1980 --- /dev/null +++ b/tests/generate_exmpale_results.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import json +from pathlib import Path +import subprocess + + +def run_examples(input_folder, configs, output_folder): + for config in configs: + file_name = config['filename'] + args = config['args'] + + input_file_path = Path(input_folder) / file_name + # Strip the ".py" from the filename and add ".output" + base_file_name = file_name[:-3] if file_name.endswith('.py') else file_name + output_file_path = Path(output_folder) / f"{base_file_name}.xml" + + # Construct the command to run the Python script + command = ["python3", str(input_file_path)] + args + + # Run the command and capture output + with open(output_file_path, 'w') as output_file: + result = subprocess.run(command, stdout=output_file, stderr=subprocess.STDOUT, text=True) + + # Optional: Check the return code to handle errors + if result.returncode != 0: + print(f"Error running {file_name}. See {output_file_path} for details.") + + +def main(input_folder, config_path, output_folder): + with open(config_path, "r") as f: + config = json.load(f) + + Path(output_folder).mkdir(parents=True, exist_ok=True) + run_examples(input_folder, config, output_folder) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process files according to a configuration and save the results.") + parser.add_argument("input_folder", type=str, help="Path to the folder containing the input files.") + parser.add_argument("config", type=str, help="Path to the configuration file.") + parser.add_argument("output_folder", type=str, help="Path to the folder where the processed files will be saved.") + args = parser.parse_args() + main(args.input_folder, args.config, args.output_folder) From da282fb04c34e0789f4ed19c70734a19d02e71fe Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 08:14:20 +0000 Subject: [PATCH 56/76] WIP --- .github/workflows/tests.yaml | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c8b217c..1151452 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -38,26 +38,29 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.8 - - name: Checkout main branch + - name: Checkout current branch uses: actions/checkout@v4 - with: - ref: main - name: Install msccl and dependencies run: | pip install --upgrade pip pip install -r requirements.txt + - name: Copy test script to temp directory + run: | + cp tests/generate_exmpale_results.py $RUNNER_TEMP/ - name: generate outputs run: | - python tests/generate_exmpale_results.py -i examples/mscclang/ -o $RUNNER_TEMP/tests/main-outputs/ -c tests/configs/example-config.json - - name: Checkout current branch + python $RUNNER_TEMP/generate_exmpale_results.py -i examples/mscclang/ -o $RUNNER_TEMP/tests/main-outputs/ -c tests/configs/example-config.json + - name: Checkout main branch uses: actions/checkout@v4 + with: + ref: main - name: Install msccl and dependencies run: | pip install --upgrade pip pip install -r requirements.txt - name: generate outputs run: | - python tests/generate_exmpale_results.py -i examples/mscclang/ -o $RUNNER_TEMP/tests/pr-outputs/ -c tests/configs/example-config.json + python $RUNNER_TEMP/generate_exmpale_results.py -i examples/mscclang/ -o $RUNNER_TEMP/tests/pr-outputs/ -c tests/configs/example-config.json - name: Compare outputs run: | diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ From 1f5114b7d3a868814c3b2b84b2add7678c321ba5 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 08:16:01 +0000 Subject: [PATCH 57/76] WIP --- .github/workflows/tests.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 1151452..c8703bf 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -49,7 +49,7 @@ jobs: cp tests/generate_exmpale_results.py $RUNNER_TEMP/ - name: generate outputs run: | - python $RUNNER_TEMP/generate_exmpale_results.py -i examples/mscclang/ -o $RUNNER_TEMP/tests/main-outputs/ -c tests/configs/example-config.json + python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ $RUNNER_TEMP/tests/pr-outputs/ tests/configs/example-config.json - name: Checkout main branch uses: actions/checkout@v4 with: @@ -60,7 +60,7 @@ jobs: pip install -r requirements.txt - name: generate outputs run: | - python $RUNNER_TEMP/generate_exmpale_results.py -i examples/mscclang/ -o $RUNNER_TEMP/tests/pr-outputs/ -c tests/configs/example-config.json + python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ $RUNNER_TEMP/tests/main-outputs/ tests/configs/example-config.json - name: Compare outputs run: | diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ From 5e978bfa49899b760f17198021fe87a62c78e2ac Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 08:17:48 +0000 Subject: [PATCH 58/76] WIP --- .github/workflows/tests.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c8703bf..fefcdce 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -49,7 +49,7 @@ jobs: cp tests/generate_exmpale_results.py $RUNNER_TEMP/ - name: generate outputs run: | - python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ $RUNNER_TEMP/tests/pr-outputs/ tests/configs/example-config.json + python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ tests/configs/example-config.json $RUNNER_TEMP/tests/pr-outputs/ - name: Checkout main branch uses: actions/checkout@v4 with: @@ -60,7 +60,7 @@ jobs: pip install -r requirements.txt - name: generate outputs run: | - python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ $RUNNER_TEMP/tests/main-outputs/ tests/configs/example-config.json + python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ tests/configs/example-config.json $RUNNER_TEMP/tests/main-outputs/ - name: Compare outputs run: | diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ From a90647ece7f7d951cc54f8338b1a5f5786af8d0b Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 08:19:59 +0000 Subject: [PATCH 59/76] WIP --- .github/workflows/tests.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index fefcdce..9559060 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -44,12 +44,13 @@ jobs: run: | pip install --upgrade pip pip install -r requirements.txt - - name: Copy test script to temp directory + - name: Copy test script/config to temp directory run: | cp tests/generate_exmpale_results.py $RUNNER_TEMP/ + cp tests/configs/example-config.json $RUNNER_TEMP/ - name: generate outputs run: | - python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ tests/configs/example-config.json $RUNNER_TEMP/tests/pr-outputs/ + python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ $RUNNER_TEMP/example-config.json $RUNNER_TEMP/tests/pr-outputs/ - name: Checkout main branch uses: actions/checkout@v4 with: @@ -60,7 +61,7 @@ jobs: pip install -r requirements.txt - name: generate outputs run: | - python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ tests/configs/example-config.json $RUNNER_TEMP/tests/main-outputs/ + python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ $RUNNER_TEMP/example-config.json $RUNNER_TEMP/tests/main-outputs/ - name: Compare outputs run: | diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ From 0e77d88229666e14eb4a74d1ecde2e03fb671cd7 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 09:50:58 +0000 Subject: [PATCH 60/76] fix --- msccl/language/tb_assignment.py | 11 ----------- msccl/language/types.py | 1 - 2 files changed, 12 deletions(-) diff --git a/msccl/language/tb_assignment.py b/msccl/language/tb_assignment.py index 4d3c6ad..85818b3 100755 --- a/msccl/language/tb_assignment.py +++ b/msccl/language/tb_assignment.py @@ -40,17 +40,6 @@ def manual_assign_tbs(rank_dag): f"Threadblock {tbid} send:{tb.send} recv:{tb.recv} channel:{tb.channel}\n" \ f"Operation send:{op.dst.rank if op.is_send() else -1} recv:{op.dst.rank if op.is_recv() else -1} channel:{op.channel}") -def convert_to_exectuion_plan(instr_dag): - ops = instr_dag.convert_set_list() - ops = sorted(ops, key=lambda x: x.step) - for op in ops: - rank = op.rank - tbid = op.tb - if tbid not in instr_dag.tbs[rank]: - instr_dag.tbs[rank][tbid] = Threadblock(id=tbid) - tb = instr_dag.tbs[rank][tbid] - tb.ops.append(op) - def _get_tb_options(mapping, send, recv, channel, num_tbs): options = [] for tbid, tb in mapping.items(): diff --git a/msccl/language/types.py b/msccl/language/types.py index e914dfd..466ebf6 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -38,7 +38,6 @@ def scratch_size(self): @dataclass class Threadblock: - id: int = -1 channel: int = -1 send: int = -1 recv: int = -1 From e05edcf7f3d44b449614648b0d48329df26ab470 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 15:26:25 +0000 Subject: [PATCH 61/76] WIP --- tests/configs/example-config.json | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/tests/configs/example-config.json b/tests/configs/example-config.json index caffdf8..0f2dc1b 100644 --- a/tests/configs/example-config.json +++ b/tests/configs/example-config.json @@ -17,20 +17,12 @@ }, { "filename": "allreduce_1step.py", - "args": ["8", "1"] - }, - { - "filename": "allreduce_a100_allpairs_v2.py", - "args": ["8", "2"] + "args": ["2", "1"] }, { "filename": "allreduce_a100_allpairs.py", "args": ["8", "4"] }, - { - "filename": "allreduce_a100_multinode_allpairs.py", - "args": ["16", "2"] - }, { "filename": "allreduce_a100_ncv4_v2.py", "args": ["4", "2"] @@ -39,10 +31,6 @@ "filename": "allreduce_a100_ncv4.py", "args": ["4", "2"] }, - { - "filename": "allreduce_a100_pcie_hierarchical.py", - "args": ["8", "2"] - }, { "filename": "allreduce_a100_ring.py", "args": ["8", "8", "4"] @@ -51,10 +39,6 @@ "filename": "allreduce_recursive_doubling_halving.py", "args": ["8", "4"] }, - { - "filename": "alltoall_a100_three_step.py", - "args": ["2", "8", "2"] - }, { "filename": "alltoall_a100_two_step.py", "args": ["8", "2"] @@ -75,10 +59,6 @@ "filename": "hierarchical_allreduce.py", "args": ["8", "2", "2"] }, - { - "filename": "pipeline_a100_allpairs.py", - "args": ["8", "2"] - }, { "filename": "pipeline_a100_ring.py", "args": ["8", "4", "2"] From 54463bf3e0283a4a7a804329ea1eb65900f4f49a Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 15:30:28 +0000 Subject: [PATCH 62/76] WIP --- tests/configs/example-config.json | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/configs/example-config.json b/tests/configs/example-config.json index 0f2dc1b..a4bab6e 100644 --- a/tests/configs/example-config.json +++ b/tests/configs/example-config.json @@ -19,10 +19,6 @@ "filename": "allreduce_1step.py", "args": ["2", "1"] }, - { - "filename": "allreduce_a100_allpairs.py", - "args": ["8", "4"] - }, { "filename": "allreduce_a100_ncv4_v2.py", "args": ["4", "2"] From ca3f10aec1bbb439981a3ba9401922669cd45858 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 15:31:50 +0000 Subject: [PATCH 63/76] WIP --- tests/configs/example-config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/configs/example-config.json b/tests/configs/example-config.json index a4bab6e..7e2946b 100644 --- a/tests/configs/example-config.json +++ b/tests/configs/example-config.json @@ -9,7 +9,7 @@ }, { "filename": "allgather_recursive_doubling.py", - "args": ["8", "4"] + "args": ["2", "4"] }, { "filename": "allgather_ring.py", From b9c9734d3a540e09c4e5e2b6b3e22ee9da80f3fb Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 15:33:54 +0000 Subject: [PATCH 64/76] add back --- tests/configs/example-config.json | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/configs/example-config.json b/tests/configs/example-config.json index 7e2946b..f1c4ccc 100644 --- a/tests/configs/example-config.json +++ b/tests/configs/example-config.json @@ -19,6 +19,10 @@ "filename": "allreduce_1step.py", "args": ["2", "1"] }, + { + "filename": "allreduce_a100_allpairs.py", + "args": ["2", "4"] + }, { "filename": "allreduce_a100_ncv4_v2.py", "args": ["4", "2"] From a3b97457eedccf5d945428ffd29bdb753aedc249 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 29 Apr 2024 15:39:06 +0000 Subject: [PATCH 65/76] update --- .github/workflows/codeql.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 12a15f0..6f24ed2 100755 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -19,12 +19,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Initialize CodeQL - uses: github/codeql-action/init@v1 + uses: github/codeql-action/init@v3 with: languages: python - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + uses: github/codeql-action/analyze@v3 From d81936fd69804cc669879e6de5fae17c55d9585e Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 30 Apr 2024 06:14:51 +0000 Subject: [PATCH 66/76] Fix --- .github/workflows/tests.yaml | 15 +++++++++------ ...ale_results.py => generate_example_results.py} | 0 2 files changed, 9 insertions(+), 6 deletions(-) rename tests/{generate_exmpale_results.py => generate_example_results.py} (100%) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 9559060..2a49e66 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -31,13 +31,16 @@ jobs: compare_outputs: runs-on: ubuntu-latest - name: Compare outputs + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10'] + name: Compare outputs with Python ${{ matrix.python-version }} steps: - - name: Set up Python 3.8 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: ${{ matrix.python-version }} - name: Checkout current branch uses: actions/checkout@v4 - name: Install msccl and dependencies @@ -46,11 +49,11 @@ jobs: pip install -r requirements.txt - name: Copy test script/config to temp directory run: | - cp tests/generate_exmpale_results.py $RUNNER_TEMP/ + cp tests/generate_example_results.py $RUNNER_TEMP/ cp tests/configs/example-config.json $RUNNER_TEMP/ - name: generate outputs run: | - python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ $RUNNER_TEMP/example-config.json $RUNNER_TEMP/tests/pr-outputs/ + python $RUNNER_TEMP/generate_example_results.py examples/mscclang/ $RUNNER_TEMP/example-config.json $RUNNER_TEMP/tests/pr-outputs/ - name: Checkout main branch uses: actions/checkout@v4 with: @@ -61,7 +64,7 @@ jobs: pip install -r requirements.txt - name: generate outputs run: | - python $RUNNER_TEMP/generate_exmpale_results.py examples/mscclang/ $RUNNER_TEMP/example-config.json $RUNNER_TEMP/tests/main-outputs/ + python $RUNNER_TEMP/generate_example_results.py examples/mscclang/ $RUNNER_TEMP/example-config.json $RUNNER_TEMP/tests/main-outputs/ - name: Compare outputs run: | diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ diff --git a/tests/generate_exmpale_results.py b/tests/generate_example_results.py similarity index 100% rename from tests/generate_exmpale_results.py rename to tests/generate_example_results.py From 2651faff03e3a6504ffabda2b061dd00b197450d Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 30 Apr 2024 06:17:57 +0000 Subject: [PATCH 67/76] revert --- .github/workflows/tests.yaml | 2 +- pytest.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 2a49e66..3392790 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -25,7 +25,7 @@ jobs: run: | pip install --upgrade pip pip install -r requirements.txt - - name: Run tests and check at least 85% coverage + - name: Run tests and check at least 90% coverage run: | pytest diff --git a/pytest.ini b/pytest.ini index 4621e92..d68bf05 100755 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,2 @@ [pytest] -addopts = --cov=msccl --cov-report term-missing:skip-covered --cov-fail-under 85 -n auto +addopts = --cov=msccl --cov-report term-missing:skip-covered --cov-fail-under 90 -n auto From bb7a584c9a163976b8b5c6c13911a03158bea0f1 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 11:47:52 +0000 Subject: [PATCH 68/76] address comments --- .github/workflows/tests.yaml | 20 ++++++++++++------- .../{example-config.json => test-config.json} | 0 ...le_results.py => generate_test_results.py} | 0 3 files changed, 13 insertions(+), 7 deletions(-) rename tests/configs/{example-config.json => test-config.json} (100%) rename tests/{generate_example_results.py => generate_test_results.py} (100%) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 3392790..63ae367 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,6 +1,12 @@ name: Tests on: + workflow_dispatch: + inputs: + commit_hash: + description: 'The commit hash which would be compared to' + required: true + default: 'main' push: pull_request: branches: [ main ] @@ -21,7 +27,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install msccl and dependencies + - name: Install msccl-tools and dependencies run: | pip install --upgrade pip pip install -r requirements.txt @@ -43,28 +49,28 @@ jobs: python-version: ${{ matrix.python-version }} - name: Checkout current branch uses: actions/checkout@v4 - - name: Install msccl and dependencies + - name: Install msccl-tools and dependencies run: | pip install --upgrade pip pip install -r requirements.txt - name: Copy test script/config to temp directory run: | - cp tests/generate_example_results.py $RUNNER_TEMP/ - cp tests/configs/example-config.json $RUNNER_TEMP/ + cp tests/generate_test_results.py $RUNNER_TEMP/ + cp tests/configs/test-config.json $RUNNER_TEMP/ - name: generate outputs run: | - python $RUNNER_TEMP/generate_example_results.py examples/mscclang/ $RUNNER_TEMP/example-config.json $RUNNER_TEMP/tests/pr-outputs/ + python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/pr-outputs/ - name: Checkout main branch uses: actions/checkout@v4 with: - ref: main + ref: ${{ inputs.commit_hash }} - name: Install msccl and dependencies run: | pip install --upgrade pip pip install -r requirements.txt - name: generate outputs run: | - python $RUNNER_TEMP/generate_example_results.py examples/mscclang/ $RUNNER_TEMP/example-config.json $RUNNER_TEMP/tests/main-outputs/ + python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/main-outputs/ - name: Compare outputs run: | diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ diff --git a/tests/configs/example-config.json b/tests/configs/test-config.json similarity index 100% rename from tests/configs/example-config.json rename to tests/configs/test-config.json diff --git a/tests/generate_example_results.py b/tests/generate_test_results.py similarity index 100% rename from tests/generate_example_results.py rename to tests/generate_test_results.py From d072a34d95bd446c4299065fe174dbd6ece490b4 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 12:02:58 +0000 Subject: [PATCH 69/76] WIP --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 63ae367..ae20c26 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -63,7 +63,7 @@ jobs: - name: Checkout main branch uses: actions/checkout@v4 with: - ref: ${{ inputs.commit_hash }} + ref: ${{ github.event.inputs.commit_hash }} - name: Install msccl and dependencies run: | pip install --upgrade pip From dda74f9d3a6e5857c322a23b24deafe77a0db794 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 12:17:22 +0000 Subject: [PATCH 70/76] Fix --- .github/workflows/tests.yaml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index ae20c26..b56be05 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -4,7 +4,7 @@ on: workflow_dispatch: inputs: commit_hash: - description: 'The commit hash which would be compared to' + description: 'Git SHA or branch to comapre with' required: true default: 'main' push: @@ -60,10 +60,16 @@ jobs: - name: generate outputs run: | python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/pr-outputs/ - - name: Checkout main branch + - name: Checkout repo (workflow_dispatch) + if: github.event_name == 'workflow_dispatch' uses: actions/checkout@v4 with: ref: ${{ github.event.inputs.commit_hash }} + - name: Checkout main branch (pull_request) + if: github.event_name == 'pull_request' + uses: actions/checkout@v4 + with: + ref: 'main' - name: Install msccl and dependencies run: | pip install --upgrade pip From f47dfe977d1baa3084c92d5c2576ba573865d28a Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 12:20:28 +0000 Subject: [PATCH 71/76] Fix --- .github/workflows/tests.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b56be05..c710d3f 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -60,13 +60,13 @@ jobs: - name: generate outputs run: | python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/pr-outputs/ - - name: Checkout repo (workflow_dispatch) + - name: Checkout specific branch if: github.event_name == 'workflow_dispatch' uses: actions/checkout@v4 with: ref: ${{ github.event.inputs.commit_hash }} - - name: Checkout main branch (pull_request) - if: github.event_name == 'pull_request' + - name: Checkout main branch + if: github.event_name == 'pull_request' || github.event_name == 'push' uses: actions/checkout@v4 with: ref: 'main' From daef76ab3f1a4f6f5a7b36b707be976ad83d9ccf Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 12:23:29 +0000 Subject: [PATCH 72/76] WIP --- .github/workflows/tests.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c710d3f..cc71639 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -60,13 +60,12 @@ jobs: - name: generate outputs run: | python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/pr-outputs/ - - name: Checkout specific branch - if: github.event_name == 'workflow_dispatch' - uses: actions/checkout@v4 - with: - ref: ${{ github.event.inputs.commit_hash }} + # - name: Checkout specific branch + # if: github.event_name == 'workflow_dispatch' + # uses: actions/checkout@v4 + # with: + # ref: ${{ github.event.inputs.commit_hash }} - name: Checkout main branch - if: github.event_name == 'pull_request' || github.event_name == 'push' uses: actions/checkout@v4 with: ref: 'main' From 3d2d8385eff7bd568d49582ecbbcb83476879cd9 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 12:24:53 +0000 Subject: [PATCH 73/76] WIP --- .github/workflows/tests.yaml | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index cc71639..499b159 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -4,7 +4,7 @@ on: workflow_dispatch: inputs: commit_hash: - description: 'Git SHA or branch to comapre with' + description: 'The commit hash which would be compared to' required: true default: 'main' push: @@ -60,15 +60,10 @@ jobs: - name: generate outputs run: | python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/pr-outputs/ - # - name: Checkout specific branch - # if: github.event_name == 'workflow_dispatch' - # uses: actions/checkout@v4 - # with: - # ref: ${{ github.event.inputs.commit_hash }} - name: Checkout main branch - uses: actions/checkout@v4 - with: - ref: 'main' + uses: actions/checkout@v4 + with: + ref: main - name: Install msccl and dependencies run: | pip install --upgrade pip From 5a34cd5c8b4ae62b84368ea399bf6554b1880fbc Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 12:25:55 +0000 Subject: [PATCH 74/76] WIP --- .github/workflows/tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 499b159..e5311a5 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -62,6 +62,7 @@ jobs: python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/pr-outputs/ - name: Checkout main branch uses: actions/checkout@v4 + if: github.event_name == 'pull_request' || github.event_name == 'push' with: ref: main - name: Install msccl and dependencies From 02b5de94876416b55240c41599fcce34edbbb74f Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 12:27:18 +0000 Subject: [PATCH 75/76] WIP --- .github/workflows/tests.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index e5311a5..97cf8a6 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -60,6 +60,11 @@ jobs: - name: generate outputs run: | python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/pr-outputs/ + - name: Checkout specific branch + if: github.event_name == 'workflow_dispatch' + uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.commit_hash }} - name: Checkout main branch uses: actions/checkout@v4 if: github.event_name == 'pull_request' || github.event_name == 'push' From 06d77762ba1dfc046b249f3d02c4dcbc6f0120ed Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 6 May 2024 12:33:50 +0000 Subject: [PATCH 76/76] done --- .github/workflows/tests.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 97cf8a6..133e11a 100755 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -4,9 +4,9 @@ on: workflow_dispatch: inputs: commit_hash: - description: 'The commit hash which would be compared to' + description: 'The git commit hash to compare against' required: true - default: 'main' + default: 'fa5accc63ac39840422ff0d6b0ee875706c95e90' # legacy main branch commit hash push: pull_request: branches: [ main ]