Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Apr 25, 2024
1 parent 8cfcc30 commit 2ba1760
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def allreduce_allpairs(gpus, instances):
c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)

Json()
# Check()
Check()


parser = argparse.ArgumentParser()
Expand Down
1 change: 1 addition & 0 deletions examples/mscclang/allreduce_a100_allpairs_sm_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def allreduce_allpairs(gpus, instances, protocol):
c_peer.wait(nghr, Buffer.input, peer_index + tb, recvtb=tb)

Json()
Check()


parser = argparse.ArgumentParser()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def allreduce_allpairs(gpus, instances, protocol):
c.get(nghr, Buffer.input, index + tb, recvtb=tb)

Json()
Check()


parser = argparse.ArgumentParser()
Expand Down
8 changes: 5 additions & 3 deletions msccl/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from msccl.language.mscclpp import *
from typing import Union

from msccl.language.types import ThreadblockPolicy
from msccl.language.types import ReplicationPolicy, ThreadblockPolicy

# from msccl.language.visualize import *

Expand Down Expand Up @@ -49,7 +49,9 @@ def __init__(
self.instances = instances
self.protocol = protocol
self.threadblock_policy = threadblock_policy
self.interleaved_replication = interleaved_replication
self.replication_policy = (
ReplicationPolicy.interleaved if interleaved_replication else ReplicationPolicy.duplicated
)
self.instr_fusion = instr_fusion
self.check_xml = check_xml
self.dependence_nop = dependence_nop
Expand Down Expand Up @@ -136,7 +138,7 @@ def lower(self):
else:
auto_assign_tbs(self.instr_dag)
self.instr_dag.lower_pt1(self.instances)
gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.interleaved_replication)
gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy)
if self.check_xml:
# Check generated MSCCL-IR for correctness - no circular dependencies, sends and receives are ordered
# For very large programs, turn off check_xml when shipping
Expand Down
12 changes: 6 additions & 6 deletions msccl/language/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict

from msccl.language.buffer import Buffer
from msccl.language.types import ChunkRef, Gpu, InstancePolicy, Instruction, Op, Threadblock
from msccl.language.types import ChunkRef, Gpu, Instruction, Op, ReplicationPolicy, Threadblock


def remove_op(op: Op):
Expand Down Expand Up @@ -203,16 +203,16 @@ def lower_pt1(self, instances: int):
self._infer_dependencies()
self._lower_buffers(instances)

def lower_pt2(self, instances: int, instance_pollicy: InstancePolicy):
self.replicate(instances, instance_pollicy)
def lower_pt2(self, instances: int, replication_policy: ReplicationPolicy):
self.replicate(instances, replication_policy)
return self._lower_tbs()

@abstractmethod
def optimize(self):
pass

@abstractmethod
def replicate(self, instances: int, instance_policy: InstancePolicy):
def replicate(self, instances: int, replication_policy: ReplicationPolicy):
pass


Expand Down Expand Up @@ -382,7 +382,7 @@ def _optimize_rrcs_rrs(self):
# only interleaved replication will be correct
# Interleaved policy only supports single count sends/receives from the input/output buffer
# (multicount ops are fine between scratch)
def replicate(self, instances, interleaved):
def replicate(self, instances, replication_policy: ReplicationPolicy):
if instances == 1:
self.instanced_tbs = self.tbs
return
Expand All @@ -401,7 +401,7 @@ def get_new_index(rank, buffer, index, size, i):
return buf_instance_len * i + index
# If this is operating on the input/output buffer then replication strategy can be either interleaved or batched
# This is to fit with the semantics of certain collectives
elif interleaved:
elif replication_policy == ReplicationPolicy.interleaved:
return index * instances + i * size
else:
return len(self.buffers[rank][buffer]) * i + index
Expand Down
9 changes: 5 additions & 4 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
instances: int,
protocol: str = "Simple",
instr_fusion: bool = True,
instance_policy: InstancePolicy = InstancePolicy.duplicated,
replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated,
):
self.name = name
self.topo = topo
Expand All @@ -40,7 +40,7 @@ def __init__(
self.instances = instances
self.protocol = protocol
self.instr_fusion = instr_fusion
self.instance_policy = instance_policy
self.replication_policy = replication_policy
assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL"
self.run_opt = True # Runs optimization passes
# Initialize the input buffers
Expand Down Expand Up @@ -114,7 +114,7 @@ def lower(self):
if self.instr_fusion:
self.instr_dag.optimize()
self.instr_dag.lower_pt1(self.instances)
gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.instance_policy)
gpu_prgms = self.instr_dag.lower_pt2(self.instances, self.replication_policy)
return Program(
self.name,
self.collective.name,
Expand Down Expand Up @@ -189,9 +189,10 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm,
self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
else:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
return dst_chunkref

def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
self._put(dst, buffer, index, sendtb, chan_type)
return self._put(dst, buffer, index, sendtb, chan_type)

def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True)
Expand Down
6 changes: 3 additions & 3 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
same_src_dst_buffer_type,
)
from msccl.language.instruction_dag import InstructionDAG
from msccl.language.types import ChunkRef, InstancePolicy, MscclppInstruction as Instruction, Op, Threadblock
from msccl.language.types import ChunkRef, MscclppInstruction as Instruction, Op, ReplicationPolicy, Threadblock


class MscclppInstructionDAG(InstructionDAG):
Expand Down Expand Up @@ -631,7 +631,7 @@ def optimize(self):

self._parallel_signal_wait()

def replicate(self, instances: int, instance_policy: InstancePolicy):
def replicate(self, instances: int, replication_policy: ReplicationPolicy):
# update op step
for rank, rank_tbs in enumerate(self.tbs):
for _, tb in rank_tbs.items():
Expand Down Expand Up @@ -661,7 +661,7 @@ def get_instance_ref(ref):
iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size)
return iref

if instance_policy == InstancePolicy.duplicated:
if replication_policy == ReplicationPolicy.duplicated:
for i in range(instances):
# Generate all the threadblocks and ops
for rank, rank_tbs in enumerate(self.tbs):
Expand Down
7 changes: 7 additions & 0 deletions msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def ir_to_json(program: Program):
gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks)
gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks)

# Since LL protocol will double the scratch size. We need to make sure all GPUs have the same scratch size.
# Otherwise the offset calculation will be wrong.
if program.protocol == "LL":
max_scratch = max(gpu.scratch_chunks for gpu in program.gpus)
for gpu in program.gpus:
gpu.scratch_chunks = max_scratch

# get channel info for each GPU and threadblock
for gpu in program.gpus:
gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id)
Expand Down
2 changes: 1 addition & 1 deletion msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __str__(self):
return self.value


class InstancePolicy(Enum):
class ReplicationPolicy(Enum):
# this means pack multi instrances to deal with the same chunk and share the channels
packed = "packed"
# this means each instance deal with the different chunk
Expand Down
75 changes: 68 additions & 7 deletions tests/test_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

import msccl
from msccl.language.types import MscclppInstruction
from msccl.topologies import line, fully_connected
from msccl.language import *
from msccl.language.routines import *
Expand All @@ -16,15 +17,15 @@ def init_buffers(self):
rank_buffers = []
for r in range(self.num_ranks):
input_buffer = [None] * chunks_per_node
output_buffer = [None] * chunks_per_node
output_buffer = [None] * chunks_per_node
if r == 0:
for c in range(chunks_per_node):
input_buffer[c] = Chunk(r, c, 2, c)
buffers = {Buffer.input : input_buffer,
buffers = {Buffer.input : input_buffer,
Buffer.output : output_buffer}
rank_buffers.append(buffers)
return rank_buffers


# Final state chunk0 from rank0 is in the output buffer of rank2
def check(self, prog):
Expand All @@ -46,14 +47,14 @@ def init_buffers(self):
rank_buffers = []
for r in range(self.num_ranks):
input_buffer = [None] * chunks_per_node
output_buffer = [None] * chunks_per_node
output_buffer = [None] * chunks_per_node
for c in range(chunks_per_node):
input_buffer[c] = Chunk(r, c, -1, c)
buffers = {Buffer.input : input_buffer,
buffers = {Buffer.input : input_buffer,
Buffer.output : output_buffer}
rank_buffers.append(buffers)
return rank_buffers


# Final state rank2 has a fully reduced chunk from gpus 0, 1, and 2
def check(self, prog):
Expand Down Expand Up @@ -201,6 +202,32 @@ def test_instruction_fusion():
assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == Instruction.recv
assert lowered_prgm.gpus[2].threadblocks[0].ops[0].inst == Instruction.recv_reduce_copy_send


def test_instruction_fusion_mscclpp():
topology = fully_connected(3)
collective = AllReduce(3, 3, True)
prgm = MSCCLPPProgram("allreduce", topology, collective, 1)
with prgm:
c01 = chunk(1, Buffer.input, 0, 3).reduce(chunk(0, Buffer.input, 0, 3), recvtb=0)
c01.signal(2, Buffer.input, 0, sendtb=0)
c012 = chunk(2, Buffer.input, 0, 3)
c012.wait(1, Buffer.input, 0, recvtb=0)
c012.reduce(c01, recvtb=0).put(0, Buffer.input, 0, sendtb=0)
c012.signal(0, Buffer.input, 0, sendtb=0)
c0 = chunk(0, Buffer.input, 0, 3)
c0.wait(2, Buffer.input, 0, recvtb=0)
c0.put(1, Buffer.input, 0, sendtb=0)
assert Check()
lowered_prgm = prgm.lower()
assert lowered_prgm.gpus[0].threadblocks[0].ops[0].inst == MscclppInstruction.wait
assert lowered_prgm.gpus[0].threadblocks[0].ops[1].inst == MscclppInstruction.put
assert lowered_prgm.gpus[1].threadblocks[0].ops[0].inst == MscclppInstruction.read_reduce_copy
assert lowered_prgm.gpus[1].threadblocks[0].ops[1].inst == MscclppInstruction.signal
assert lowered_prgm.gpus[2].threadblocks[0].ops[0].inst == MscclppInstruction.wait
assert lowered_prgm.gpus[2].threadblocks[0].ops[1].inst == MscclppInstruction.read_reduce_copy_send
assert lowered_prgm.gpus[2].threadblocks[0].ops[2].inst == MscclppInstruction.signal


def test_replication():
topology = fully_connected(2)
collective = AllToAll(2, 1, False)
Expand Down Expand Up @@ -287,4 +314,38 @@ def test_routines_allreduce_nodes():
c.reduce(chunk(r, Buffer.output, exchange_index, 4))
c = c.copy(r, Buffer.output, exchange_index)
XML()
assert Check()
assert Check()

def test_routines_allreduce_packet_inplace_mscclpp():
size = 8
topology = fully_connected(size)
collective = AllReduce(size, size * size, True)
with MSCCLPPProgram("allreduce_packet", topology, collective, 1, protocol="LL"):
# Each rank sends the nth chunk to the nth rank into scratch space
for r1 in range(size):
for tb in range(size):
if tb == r1:
continue
remote_rank = tb
index = remote_rank * size
c = chunk(r1, Buffer.input, index, size)
c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb)
# Each rank performs a local reduction on the nth chunk
# Utilize 8 threadblocks for this reduction for better parallelism
for r in range(size):
for index in range(size):
c = chunk(r, Buffer.input, r * size + index)
for peer in range(size):
if peer != r:
c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index)
for peer in range(size):
if peer != r:
c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index)
# Each rank get final result from scratch space
for r in range(size):
for peer in range(size):
if peer != r:
c = chunk(r, "scratch", size * size + peer * size, size)
c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)
Json()
assert Check()

0 comments on commit 2ba1760

Please sign in to comment.