Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transform_to_packet operator #11

Merged
merged 14 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def allreduce_allpairs(gpus, instances):
instances,
protocol="LL",
):

# Each rank sends the nth chunk to the nth rank into scratch space
for r1 in range(size):
for tb in range(size):
Expand All @@ -28,7 +27,7 @@ def allreduce_allpairs(gpus, instances):
remote_rank = tb
index = remote_rank * size
c = chunk(r1, Buffer.input, index, size)
c.put_packet(remote_rank, "scratch", index=r1*size, sendtb=tb)
c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb)

# Each rank performs a local reduction on the nth chunk
# Utilize 8 threadblocks for this reduction for better parallelism
Expand Down
50 changes: 50 additions & 0 deletions examples/mscclang/mscclpp/send_recv_a100_packet.mscclpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import SendRecv


def send_recv(instances):
    """Build a two-rank packet send/recv MSCCL++ program.

    Inside an ``MSCCLPPProgram`` context using the "LL" protocol:

    1. Every rank takes chunk 0 of its input buffer and ``put_packet``s it to
       index 1 of the other rank's "scratch" buffer over a proxy channel,
       naming index 0 of "scratch" as the temporary buffer for the put.
    2. Every rank then ``copy_packet``s index 1 of its own "scratch" buffer
       into index 0 of its output buffer.
    3. ``Json()`` and ``Check()`` are called while the program is still the
       active context (presumably to emit and validate the plan; confirm
       against msccl.language).

    Args:
        instances: forwarded unchanged to ``MSCCLPPProgram``.
    """
    num_ranks = 2
    chunks_per_loop = 1
    program = MSCCLPPProgram(
        "send_recv",
        fully_connected(num_ranks),
        SendRecv(num_ranks, chunks_per_loop, False),
        instances,
        protocol="LL",
    )
    with program:
        # Phase 1: each rank sends its first input chunk to every other rank.
        for rank in range(num_ranks):
            peers = [p for p in range(num_ranks) if p != rank]
            for peer in peers:
                outgoing = chunk(rank, Buffer.input, 0)
                outgoing.put_packet(
                    peer,
                    "scratch",
                    1,
                    sendtb=0,
                    chan_type=ChannelType.proxy,
                    temp_buffer="scratch",
                    temp_buffer_index=0,
                )

        # Phase 2: each rank moves the received packet out of scratch.
        for rank in range(num_ranks):
            incoming = chunk(rank, "scratch", 1)
            incoming.copy_packet(rank, Buffer.output, 0, sendtb=0)

        Json()
        Check()


# Command-line entry point: a single required positional integer, "instances".
parser = argparse.ArgumentParser()
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

# NOTE: there is no `if __name__ == "__main__"` guard, so argument parsing and
# program generation run whenever this module is executed or imported.
send_recv(args.instances)
46 changes: 40 additions & 6 deletions msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def check(self, prog):
chunk = output[index]
expected_origin_index = ch + r * self.chunk_factor
if chunk is None or chunk.origin_rank != i or chunk.origin_index != expected_origin_index:
print(
f"Rank {r} chunk {index} is incorrect should be chunk({i},{expected_origin_index}) given {chunk}"
)
print(f"Rank {r} chunk {index} is incorrect should be chunk({i},{expected_origin_index}) given {chunk}")
correct = False
return correct

Expand Down Expand Up @@ -162,9 +160,7 @@ def check(self, prog):
for c in range(chunks_per_node):
chunk = output[c]
if chunk is None or chunk != expected_chunks[c]:
print(
f"Rank {r} chunk {c} is incorrect should be ReduceChunk index {c} from all ranks, given {chunk}"
)
print(f"Rank {r} chunk {c} is incorrect should be ReduceChunk index {c} from all ranks, given {chunk}")
correct = False
return correct

Expand Down Expand Up @@ -228,3 +224,41 @@ def get_buffer_index(self, rank, buffer, index):
return Buffer.input, index + rank * self.chunk_factor
else:
return buffer, index


# SendRecv is a collective that sends a chunk from one rank to another
# It is used to test the correctness of the send and receive instructions
class SendRecv(Collective):
    """Two-rank collective: each rank must end up holding its peer's chunks."""

    def __init__(self, num_ranks, chunk_factor, inplace):
        """Initialise the collective; only ``num_ranks == 2`` is accepted."""
        assert num_ranks == 2, "SendRecv only supports 2 ranks"
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "sendrecv"

    def init_buffers(self):
        """Return one ``{Buffer.input, Buffer.output}`` dict per rank.

        The input buffer of rank ``r`` holds ``Chunk(r, c, -1, c)`` for every
        ``c`` below ``chunk_factor``; the output buffer starts as all ``None``.
        """
        return [
            {
                Buffer.input: [Chunk(rank, idx, -1, idx) for idx in range(self.chunk_factor)],
                Buffer.output: [None] * self.chunk_factor,
            }
            for rank in range(self.num_ranks)
        ]

    def check(self, prog):
        """Return True when every rank holds the other rank's chunks in order.

        The inspected buffer is the input buffer for in-place collectives and
        the output buffer otherwise. Every mismatch is printed; checking does
        not stop at the first failure.
        """
        ok = True
        target = Buffer.input if self.inplace else Buffer.output
        for r in range(self.num_ranks):
            peer = 1 - r  # with exactly two ranks, the other rank
            received = prog.buffers[r][target]
            for c in range(self.chunk_factor):
                chunk = received[c]
                if chunk is not None and chunk.origin_rank == peer and chunk.origin_index == c:
                    continue
                print(f"Rank {r} chunk {c} is incorrect should be ({1 - r}, {c}) given {chunk}")
                ok = False

        return ok

    def get_buffer_index(self, rank, buffer, index):
        """Map ``(buffer, index)`` to its backing location.

        For in-place collectives the output buffer aliases the input buffer;
        every other combination is returned unchanged.
        """
        aliases_input = self.inplace and buffer == Buffer.output
        return (Buffer.input if aliases_input else buffer), index
31 changes: 21 additions & 10 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
raise RuntimeError("This program is not currently in context")
_current_program = None


def _convert_to_exectuion_plan(self):
ops = self.instr_dag.convert_set_list()
ops = sorted(ops, key=lambda x: x.step)
Expand Down Expand Up @@ -197,7 +196,7 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm,
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
if use_packet:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet)
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True)
self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
else:
Expand All @@ -207,8 +206,22 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm,
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, chan_type)

def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True)
def put_packet(
    self,
    dst,
    buffer=None,
    index=-1,
    sendtb=-1,
    chan_type=ChannelType.sm,
    temp_buffer=None,
    temp_buffer_index=-1,
):
    """Put this chunk to ``dst`` in packet form.

    For every channel type except ``ChannelType.proxy`` the put is issued
    directly from this chunk. For a proxy channel the chunk is first copied
    on the local rank into ``temp_buffer`` at ``temp_buffer_index`` with
    ``trans_to_packet=True``, and the put is then issued from the chunk
    reference returned by that copy.

    Returns whatever ``_put`` returns for the chunk that was actually sent.
    """
    if chan_type != ChannelType.proxy:
        return self._put(dst, buffer, index, sendtb, chan_type, True)

    # Proxy path: stage the data as packets in the temporary buffer first.
    staged = self._copy(
        self.rank,
        temp_buffer,
        temp_buffer_index,
        sendtb,
        trans_from_packet=False,
        trans_to_packet=True,
    )
    return staged._put(dst, buffer, index, sendtb, chan_type, True)

def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
self.prog.check_buffer_exists(src, buffer)
Expand Down Expand Up @@ -249,7 +262,7 @@ def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
src_chunkref = self.prog.get_ref(src, buffer, index, self.size)
self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type)

def _copy(self, dst, buffer=None, index=-1, sendtb=-1, use_packet=False):
def _copy(self, dst, buffer=None, index=-1, sendtb=-1, trans_from_packet=False, trans_to_packet=False):
self.prog.check_buffer_exists(dst, buffer)
buffer, index = self._get_buffer_index(dst, buffer, index)

Expand All @@ -260,7 +273,7 @@ def _copy(self, dst, buffer=None, index=-1, sendtb=-1, use_packet=False):
self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)

assert self.rank == dst, "Chunk copy only supports intra-rank communication"
self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, use_packet)
self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, trans_from_packet, trans_to_packet)

return dst_chunkref

Expand All @@ -269,15 +282,13 @@ def copy(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb)

def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb, use_packet=True)
return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)

def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False):
dst = self.rank
src = other_chunkref.rank
assert self.prog.topo.link(src, dst) or src == dst, f"No link from {src} to {dst}"
self.prog.apply_reduce(
src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size
)
self.prog.apply_reduce(src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size)
if use_packet:
assert src == dst, "Packet reduce only supports intra-rank communication"

Expand Down
6 changes: 4 additions & 2 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def __init__(self, num_ranks, buffers):
super().__init__(num_ranks, buffers)

# InstructionDAG - adds a copy node
def add_copy(self, rank, send_ref, recv_ref, tb, use_packet=False):
def add_copy(self, rank, send_ref, recv_ref, tb, trans_from_packet=False, trans_to_packet=False):
tb_step = self._get_tb_step(rank, tb)
if use_packet:
if trans_from_packet:
op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
elif trans_to_packet:
op = Op(Instruction.transform_to_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
else:
op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
dstbuffer = recv_ref.buffer
Expand Down
41 changes: 29 additions & 12 deletions msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,38 @@

from msccl.language.types import Buffer, ChannelType, Op, Program, MscclppInstruction as Instruction

_local_src_insts_mscclpp = {
_local_src_insts_mscclpp: set = {
Instruction.put,
Instruction.put_packet,
Instruction.signal,
Instruction.copy,
Instruction.copy_packet,
Instruction.transform_to_packet,
Instruction.reduce,
Instruction.reduce_packet,
Instruction.reduce_send,
Instruction.reduce_send_packet,
}
_local_dst_insts_mscclpp = {
_local_dst_insts_mscclpp: set = {
Instruction.get,
Instruction.wait,
Instruction.read_reduce_copy,
Instruction.copy,
Instruction.copy_packet,
Instruction.transform_to_packet,
Instruction.reduce,
Instruction.read_reduce_copy_send,
Instruction.reduce_send,
Instruction.reduce_packet,
Instruction.reduce_send_packet,
}

_insts_no_need_sync_barrier: set = {
Instruction.copy_packet,
Instruction.reduce_packet,
Instruction.reduce_send_packet,
}


def ir_to_json(program: Program):
# Figure out sizes of buffers based on usage
Expand Down Expand Up @@ -102,16 +110,18 @@ def ir_to_json(program: Program):

# Do some additional postprocessing of operations:
# - Expand operations with dependencies with no-ops
if program.protocol != "LL": # TODO(binyli): fix this. Should based on OP type not algorithm
for gpu in program.gpus:
for tb in gpu.threadblocks:
new_ops = []
for op in tb.ops:
# Expand extra dependencies into nop operations
for i, dep in enumerate(op.depends):
new_ops.append(Op(Instruction.nop, -1, None, None, [dep]))
for gpu in program.gpus:
for tb in gpu.threadblocks:
new_ops = []
for op in tb.ops:
if op.inst in _insts_no_need_sync_barrier:
new_ops.append(op)
tb.ops = new_ops
continue
# Expand extra dependencies into nop operations
for i, dep in enumerate(op.depends):
new_ops.append(Op(Instruction.nop, -1, None, None, [dep]))
new_ops.append(op)
tb.ops = new_ops

# update step and tid for ops
for gpu in program.gpus:
Expand Down Expand Up @@ -256,7 +266,14 @@ def remove_empty_fields(d):
)
i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts))
elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet:
elif (
op.inst == Instruction.copy
or op.inst == Instruction.copy_packet
or op.inst == Instruction.transform_to_packet
):
src = op.src
dst = op.dst
elif op.inst == Instruction.transform_to_packet:
Binyang2014 marked this conversation as resolved.
Show resolved Hide resolved
src = op.src
dst = op.dst
if op.inst != Instruction.nop:
Expand Down
1 change: 1 addition & 0 deletions msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class MscclppInstruction(Enum):
copy = "copy"
reduce = "reduce"
copy_packet = "cpkt"
transform_to_packet = "tpkt"
reduce_send_packet = "rspkt"
reduce_packet = "rpkt"
put = "put"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[tool.black]
line-length = 140
line-length = 120
target-version = ['py38']
include = '\.pyi?$'
20 changes: 20 additions & 0 deletions tests/configs/test-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,25 @@
{
"filename": "pipeline_a100_ring.py",
"args": ["8", "4", "2"]
},
{
"filename": "mscclpp/allreduce_a100_allpairs_packet_mscclpp.py",
"args": ["8", "8"]
},
{
"filename": "mscclpp/allreduce_a100_allpairs_sm_mscclpp_get.py",
"args": ["8", "8"]
},
{
"filename": "mscclpp/allreduce_a100_allpairs_sm_mscclpp.py",
"args": ["8", "8"]
},
{
"filename": "mscclpp/allreduce_a100_ring_mscclpp.py",
"args": ["8", "8"]
},
{
"filename": "mscclpp/send_recv_a100_packet.mscclpp.py",
"args": ["2"]
}
]
3 changes: 2 additions & 1 deletion tests/generate_test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def run_examples(input_folder, configs, output_folder):
input_file_path = Path(input_folder) / file_name
# Strip the ".py" from the filename and add ".output"
base_file_name = file_name[:-3] if file_name.endswith('.py') else file_name
output_file_path = Path(output_folder) / f"{base_file_name}.xml"
base_file_name = base_file_name.replace("/", "_")
output_file_path = Path(output_folder) / f"{base_file_name}.output"

# Construct the command to run the Python script
command = ["python3", str(input_file_path)] + args
Expand Down
Loading