From 2ba1760011961c7dd9fb8dc0d95a04448821dc75 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 25 Apr 2024 08:28:23 +0000 Subject: [PATCH] update --- .../allreduce_a100_allpairs_packet_mscclpp.py | 2 +- .../allreduce_a100_allpairs_sm_mscclpp.py | 1 + .../allreduce_a100_allpairs_sm_mscclpp_get.py | 1 + msccl/language/__init__.py | 8 +- msccl/language/instruction_dag.py | 12 +-- msccl/language/mscclpp/__init__.py | 9 ++- msccl/language/mscclpp/instruction_dag.py | 6 +- msccl/language/mscclpp/ir.py | 7 ++ msccl/language/types.py | 2 +- tests/test_language.py | 75 +++++++++++++++++-- 10 files changed, 98 insertions(+), 25 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py index 1fad562..d9c6275 100644 --- a/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_packet_mscclpp.py @@ -50,7 +50,7 @@ def allreduce_allpairs(gpus, instances): c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) Json() - # Check() + Check() parser = argparse.ArgumentParser() diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py index 74ae223..08717c4 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py @@ -45,6 +45,7 @@ def allreduce_allpairs(gpus, instances, protocol): c_peer.wait(nghr, Buffer.input, peer_index + tb, recvtb=tb) Json() + Check() parser = argparse.ArgumentParser() diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py index 49f3606..5407676 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -59,6 +59,7 @@ def allreduce_allpairs(gpus, instances, protocol): c.get(nghr, Buffer.input, index + tb, recvtb=tb) Json() + Check() parser = argparse.ArgumentParser() diff --git a/msccl/language/__init__.py b/msccl/language/__init__.py index 9ba80a0..51cd330 100755 --- a/msccl/language/__init__.py +++ b/msccl/language/__init__.py @@ -11,7 +11,7 @@ from msccl.language.mscclpp import * from typing import Union -from msccl.language.types import ThreadblockPolicy +from msccl.language.types import ReplicationPolicy, ThreadblockPolicy # from msccl.language.visualize import * @@ -49,7 +49,9 @@ def __init__( self.instances = instances self.protocol = protocol self.threadblock_policy = threadblock_policy - self.interleaved_replication = interleaved_replication + self.replication_policy = ( + ReplicationPolicy.interleaved if interleaved_replication else ReplicationPolicy.duplicated + ) self.instr_fusion = instr_fusion self.check_xml = check_xml self.dependence_nop = dependence_nop @@ -136,7 +138,7 @@ def lower(self): else: auto_assign_tbs(self.instr_dag) self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication) + gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy) if self.check_xml: # Check generated MSCCL-IR for correctness - no circular dependencies, sends and receives are ordered # For very large programs, turn off check_xml when shipping diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py index 2e74ef0..7bd70bf 100755 --- a/msccl/language/instruction_dag.py +++ b/msccl/language/instruction_dag.py @@ -5,7 +5,7 @@ from collections import defaultdict from msccl.language.buffer import Buffer -from msccl.language.types import ChunkRef, Gpu, InstancePolicy, Instruction, Op, Threadblock +from msccl.language.types import ChunkRef, Gpu, Instruction, Op, ReplicationPolicy, Threadblock def remove_op(op: Op): @@ -203,8 +203,8 @@ def lower_pt1(self, instances: int): self._infer_dependencies() self._lower_buffers(instances) - def lower_pt2(self, instances: int, instance_pollicy: InstancePolicy): - self.replicate(instances, instance_pollicy) + def lower_pt2(self, instances: int, replication_policy: ReplicationPolicy): + self.replicate(instances, replication_policy) return self._lower_tbs() @abstractmethod @@ -212,7 +212,7 @@ def optimize(self): pass @abstractmethod - def replicate(self, instances: int, instance_policy: InstancePolicy): + def replicate(self, instances: int, replication_policy: ReplicationPolicy): pass @@ -382,7 +382,7 @@ def _optimize_rrcs_rrs(self): # only interleaved replication will be correct # Interleaved policy only supports single count sends/receives from the input/output buffer # (multicount ops are fine between scratch) - def replicate(self, instances, interleaved): + def replicate(self, instances, replication_policy: ReplicationPolicy): if instances == 1: self.instanced_tbs = self.tbs return @@ -401,7 +401,7 @@ def get_new_index(rank, buffer, index, size, i): return buf_instance_len * i + index # If this is operating on the input/output buffer then replication strategy can be either interleaved or batched # This is to fit with the semantics of certain collectives - elif interleaved: + elif replication_policy == ReplicationPolicy.interleaved: return index * instances + i * size else: return len(self.buffers[rank][buffer]) * i + index diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py index 5e8e29f..9bc27af 100644 --- a/msccl/language/mscclpp/__init__.py +++ b/msccl/language/mscclpp/__init__.py @@ -31,7 +31,7 @@ def __init__( instances: int, protocol: str = "Simple", instr_fusion: bool = True, - instance_policy: InstancePolicy = InstancePolicy.duplicated, + replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated, ): self.name = name self.topo = topo @@ -40,7 +40,7 @@ def __init__( self.instances = instances self.protocol = protocol self.instr_fusion = instr_fusion - self.instance_policy = instance_policy + self.replication_policy = replication_policy assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL" self.run_opt = True # Runs optimization passes # Initialize the input buffers @@ -114,7 +114,7 @@ def lower(self): if self.instr_fusion: self.instr_dag.optimize() self.instr_dag.lower_pt1(self.instances) - gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.instance_policy) + gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy) return Program( self.name, self.collective.name, @@ -189,9 +189,10 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none) else: self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type) + return dst_chunkref def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): - self._put(dst, buffer, index, sendtb, chan_type) + return self._put(dst, buffer, index, sendtb, chan_type) def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True) diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index 093d41d..e0bf47a 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -15,7 +15,7 @@ same_src_dst_buffer_type, ) from msccl.language.instruction_dag import InstructionDAG -from msccl.language.types import ChunkRef, InstancePolicy, MscclppInstruction as Instruction, Op, Threadblock +from msccl.language.types import ChunkRef, MscclppInstruction as Instruction, Op, ReplicationPolicy, Threadblock class MscclppInstructionDAG(InstructionDAG): @@ -631,7 +631,7 @@ def optimize(self): self._parallel_signal_wait() - def replicate(self, instances: int, instance_policy: InstancePolicy): + def replicate(self, instances: int, replication_policy: ReplicationPolicy): # update op step for rank, rank_tbs in enumerate(self.tbs): for _, tb in rank_tbs.items(): @@ -661,7 +661,7 @@ def get_instance_ref(ref): iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) return iref - if instance_policy == InstancePolicy.duplicated: + if replication_policy == ReplicationPolicy.duplicated: for i in range(instances): # Generate all the threadblocks and ops for rank, rank_tbs in enumerate(self.tbs): diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index 28df6f5..2c8f37e 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -61,6 +61,13 @@ def ir_to_json(program: Program): gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) + # Since LL protocol will double the scratch size. We need to make sure all GPUs have the same scratch size. + # Otherwise the offset calculation will be wrong. + if program.protocol == "LL": + max_scratch = max(gpu.scratch_chunks for gpu in program.gpus) + for gpu in program.gpus: + gpu.scratch_chunks = max_scratch + # get channel info for each GPU and threadblock for gpu in program.gpus: gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id) diff --git a/msccl/language/types.py b/msccl/language/types.py index 55f2b1d..97a4277 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -73,7 +73,7 @@ def __str__(self): return self.value -class InstancePolicy(Enum): +class ReplicationPolicy(Enum): # this means pack multi instrances to deal with the same chunk and share the channels packed = "packed" # this means each instance deal with the different chunk diff --git a/tests/test_language.py b/tests/test_language.py index 9c2a38d..6fc9b98 100755 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import msccl +from msccl.language.types import MscclppInstruction from msccl.topologies import line, fully_connected from msccl.language import * from msccl.language.routines import * @@ -16,15 +17,15 @@ def init_buffers(self): rank_buffers = [] for r in range(self.num_ranks): input_buffer = [None] * chunks_per_node - output_buffer = [None] * chunks_per_node + output_buffer = [None] * chunks_per_node if r == 0: for c in range(chunks_per_node): input_buffer[c] = Chunk(r, c, 2, c) - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : output_buffer} rank_buffers.append(buffers) return rank_buffers - + # Final state chunk0 from rank0 is in the output buffer of rank2 def check(self, prog): @@ -46,14 +47,14 @@ def init_buffers(self): rank_buffers = [] for r in range(self.num_ranks): input_buffer = [None] * chunks_per_node - output_buffer = [None] * chunks_per_node + output_buffer = [None] * chunks_per_node for c in range(chunks_per_node): input_buffer[c] = Chunk(r, c, -1, c) - buffers = {Buffer.input : input_buffer, + buffers = {Buffer.input : input_buffer, Buffer.output : output_buffer} rank_buffers.append(buffers) return rank_buffers - + # Final state rank2 has a fully reduced chunk from gpus 0, 1, and 2 def check(self, prog): @@ -201,6 +202,32 @@ def test_instruction_fusion(): assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == Instruction.recv assert lowered_prgm.gpus[2].threadblocks[0].ops[0].inst == Instruction.recv_reduce_copy_send + +def test_instruction_fusion_mscclpp(): + topology = fully_connected(3) + collective = AllReduce(3, 3, True) + prgm = MSCCLPPProgram("allreduce", topology, collective, 1) + with prgm: + c01 = chunk(1, Buffer.input, 0, 3).reduce(chunk(0, Buffer.input, 0, 3), recvtb=0) + c01.signal(2, Buffer.input, 0, sendtb=0) + c012 = chunk(2, Buffer.input, 0, 3) + c012.wait(1, Buffer.input, 0, recvtb=0) + c012.reduce(c01, recvtb=0).put(0, Buffer.input, 0, sendtb=0) + c012.signal(0, Buffer.input, 0, sendtb=0) + c0 = chunk(0, Buffer.input, 0, 3) + c0.wait(2, Buffer.input, 0, recvtb=0) + c0.put(1, Buffer.input, 0, sendtb=0) + assert Check() + lowered_prgm = prgm.lower() + assert lowered_prgm.gpus[0].threadblocks[0].ops[0].inst == MscclppInstruction.wait + assert lowered_prgm.gpus[0].threadblocks[0].ops[1].inst == MscclppInstruction.put + assert lowered_prgm.gpus[1].threadblocks[0].ops[0].inst == MscclppInstruction.read_reduce_copy + assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == MscclppInstruction.signal + assert lowered_prgm.gpus[2].threadblocks[0].ops[0].inst == MscclppInstruction.wait + assert lowered_prgm.gpus[2].threadblocks[0].ops[1].inst == MscclppInstruction.read_reduce_copy_send + assert lowered_prgm.gpus[2].threadblocks[0].ops[2].inst == MscclppInstruction.signal + + def test_replication(): topology = fully_connected(2) collective = AllToAll(2, 1, False) @@ -287,4 +314,38 @@ def test_routines_allreduce_nodes(): c.reduce(chunk(r, Buffer.output, exchange_index, 4)) c = c.copy(r, Buffer.output, exchange_index) XML() - assert Check() \ No newline at end of file + assert Check() + +def test_routines_allreduce_packet_inplace_mscclpp(): + size = 8 + topology = fully_connected(size) + collective = AllReduce(size, size * size, True) + with MSCCLPPProgram("allreduce_packet", topology, collective, 1, protocol="LL"): + # Each rank sends the nth chunk to the nth rank into scratch space + for r1 in range(size): + for tb in range(size): + if tb == r1: + continue + remote_rank = tb + index = remote_rank * size + c = chunk(r1, Buffer.input, index, size) + c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb) + # Each rank performs a local reduction on the nth chunk + # Utilize 8 threadblocks for this reduction for better parallelism + for r in range(size): + for index in range(size): + c = chunk(r, Buffer.input, r * size + index) + for peer in range(size): + if peer != r: + c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index) + for peer in range(size): + if peer != r: + c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index) + # Each rank get final result from scratch space + for r in range(size): + for peer in range(size): + if peer != r: + c = chunk(r, "scratch", size * size + peer * size, size) + c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) + Json() + assert Check()