Skip to content

Commit

Permalink
adjust put_packet operation for proxy channels
Browse files Browse the repository at this point in the history
  • Loading branch information
caiomcbr committed Aug 29, 2024
1 parent 3c9bcff commit 3e0e517
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def allreduce_allpairs(gpus, instances):
):
rank = 0
c = chunk(rank, Buffer.input, 0, 1)
scartch_chunk = chunk(rank, "scratch", 0, size)
c.put_packet(1, "scratch", index=1, -1, ChannelType.proxy, temp_chunk=scartch_chunk)
c.put_packet(1, "scratch", index=1, sendtb=0, channel_type=ChannelType.proxy)
# # Each rank sends the nth chunk to the nth rank into scratch space
# for r1 in range(size):
# for tb in range(size):
Expand Down
10 changes: 7 additions & 3 deletions msccl/language/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ def _get_buffer_index(self, remote_rank, buffer, index):

def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False):
self.prog.check_buffer_exists(dst, buffer)
temp_chunk=None
if chan_type == ChannelType.proxy:
self.prog.check_buffer_exists(self.rank, "scratch")
temp_chunk = self.prog.get_ref(self.rank, "scratch", 0, self.size)
assert self.rank != dst, "Cannot put to the same rank"
buffer, index = self._get_buffer_index(dst, buffer, index)

Expand All @@ -197,17 +201,17 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm,
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
if use_packet:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet)
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet, temp_chunk)
self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
else:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, temp_chunk)
return dst_chunkref

def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, chan_type)

def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm, temp_chunk=None):
def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True)

def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
Expand Down
3 changes: 3 additions & 0 deletions msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ def remove_empty_fields(d):
elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet or op.inst == Instruction.transform_to_packet:
src = op.src
dst = op.dst
elif op.inst == Instruction.transform_to_packet:
src = op.src
dst = op.dst
if op.inst != Instruction.nop:
instr = {
"name": op.inst.value,
Expand Down

0 comments on commit 3e0e517

Please sign in to comment.