Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Aug 27, 2024
1 parent e1f41fb commit 3c9bcff
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 32 deletions.
61 changes: 32 additions & 29 deletions examples/mscclang/mscclpp/allreduce_a100_allpairs_packet_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,41 @@ def allreduce_allpairs(gpus, instances):
instances,
protocol="LL",
):
rank = 0
c = chunk(rank, Buffer.input, 0, 1)
scartch_chunk = chunk(rank, "scratch", 0, size)
c.put_packet(1, "scratch", index=1, -1, ChannelType.proxy, temp_chunk=scartch_chunk)
# # Each rank sends the nth chunk to the nth rank into scratch space
# for r1 in range(size):
# for tb in range(size):
# if tb == r1:
# continue
# remote_rank = tb
# index = remote_rank * size
# c = chunk(r1, Buffer.input, index, size)
# c.put_packet(remote_rank, "scratch", index=r1*size, sendtb=tb)

# Each rank sends the nth chunk to the nth rank into scratch space
for r1 in range(size):
for tb in range(size):
if tb == r1:
continue
remote_rank = tb
index = remote_rank * size
c = chunk(r1, Buffer.input, index, size)
c.put_packet(remote_rank, "scratch", index=r1*size, sendtb=tb)

# Each rank performs a local reduction on the nth chunk
# Utilize 8 threadblocks for this reduction for better parallelism
for r in range(size):
for index in range(size):
c = chunk(r, Buffer.input, r * size + index)
for peer in range(size):
if peer != r:
c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index)
for peer in range(size):
if peer != r:
c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index)

# Each rank gets the final result from scratch space
for r in range(size):
for peer in range(size):
if peer != r:
c = chunk(r, "scratch", size * size + peer * size, size)
c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)
# # Each rank performs a local reduction on the nth chunk
# # Utilize 8 threadblocks for this reduction for better parallelism
# for r in range(size):
# for index in range(size):
# c = chunk(r, Buffer.input, r * size + index)
# for peer in range(size):
# if peer != r:
# c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index)
# for peer in range(size):
# if peer != r:
# c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index)

# # Each rank gets the final result from scratch space
# for r in range(size):
# for peer in range(size):
# if peer != r:
# c = chunk(r, "scratch", size * size + peer * size, size)
# c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)

Json()
Check()
# Check()


parser = argparse.ArgumentParser()
Expand Down
2 changes: 1 addition & 1 deletion msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm,
# Thin wrapper: forwards dst, buffer, index, sendtb and chan_type to _put unchanged.
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, chan_type)

def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm):
# Packet variant of put: delegates to _put with use_packet=True.
# Note the keyword here is `channel_type`, whereas put() names it `chan_type`.
# NOTE(review): temp_chunk is accepted but is not passed on to _put, so it
# currently has no effect on the call below — confirm whether it should be
# forwarded once _put supports it.
def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm, temp_chunk=None):
return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True)

def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
Expand Down
35 changes: 34 additions & 1 deletion msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,41 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False):
return op

# InstructionDAG - adds a put node
def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False):
def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False, temp_chunk=None):
tb_step = self._get_tb_step(rank, tb)
if ch_type == ChannelType.proxy and temp_chunk is not None:
op = Op(
Instruction.transform_to_packet,
rank,
send_ref,
temp_chunk,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
tb_step = self._get_tb_step(rank, tb)
op2 = Op(
Instruction.put,
rank,
send_ref,
recv_ref,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
self._read(rank, buffer, index, size, op)
self._write(rank, temp_chunk.buffer, temp_chunk.index, temp_chunk.size, op)
self._read(rank, temp_chunk.buffer, temp_chunk.index, temp_chunk.size, op2)
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
op2.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
return op
if use_packet:
op = Op(
Instruction.put_packet,
Expand Down
4 changes: 3 additions & 1 deletion msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Instruction.signal,
Instruction.copy,
Instruction.copy_packet,
Instruction.transform_to_packet,
Instruction.reduce,
Instruction.reduce_packet,
Instruction.reduce_send,
Expand All @@ -23,6 +24,7 @@
Instruction.read_reduce_copy,
Instruction.copy,
Instruction.copy_packet,
Instruction.transform_to_packet,
Instruction.reduce,
Instruction.read_reduce_copy_send,
Instruction.reduce_send,
Expand Down Expand Up @@ -256,7 +258,7 @@ def remove_empty_fields(d):
)
i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts))
elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet:
elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet or op.inst == Instruction.transform_to_packet:
src = op.src
dst = op.dst
if op.inst != Instruction.nop:
Expand Down
1 change: 1 addition & 0 deletions msccl/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class MscclppInstruction(Enum):
copy = "copy"
reduce = "reduce"
copy_packet = "cpkt"
transform_to_packet = "tpkt"
reduce_send_packet = "rspkt"
reduce_packet = "rpkt"
put = "put"
Expand Down

0 comments on commit 3c9bcff

Please sign in to comment.