From 3e0e51716e33462d45c0a8717bc229add9e361e2 Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Thu, 29 Aug 2024 05:44:37 +0000 Subject: [PATCH] adjust put_packet operation for proxy channels --- .../mscclpp/allreduce_a100_allpairs_packet_mscclpp.py | 3 +-- msccl/language/mscclpp/__init__.py | 10 +++++++--- msccl/language/mscclpp/ir.py | 3 +++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/mscclang/mscclpp/allreduce_a100_allpairs_packet_mscclpp.py b/examples/mscclang/mscclpp/allreduce_a100_allpairs_packet_mscclpp.py index d24942e..4ad5aba 100644 --- a/examples/mscclang/mscclpp/allreduce_a100_allpairs_packet_mscclpp.py +++ b/examples/mscclang/mscclpp/allreduce_a100_allpairs_packet_mscclpp.py @@ -21,8 +21,7 @@ def allreduce_allpairs(gpus, instances): ): rank = 0 c = chunk(rank, Buffer.input, 0, 1) - scartch_chunk = chunk(rank, "scratch", 0, size) - c.put_packet(1, "scratch", index=1, -1, ChannelType.proxy, temp_chunk=scartch_chunk) + c.put_packet(1, "scratch", index=1, sendtb=0, channel_type=ChannelType.proxy) # # Each rank sends the nth chunk to the nth rank into scratch space # for r1 in range(size): # for tb in range(size): diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py index 3a02ff3..15258ed 100644 --- a/msccl/language/mscclpp/__init__.py +++ b/msccl/language/mscclpp/__init__.py @@ -189,6 +189,10 @@ def _get_buffer_index(self, remote_rank, buffer, index): def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False): self.prog.check_buffer_exists(dst, buffer) + temp_chunk=None + if chan_type == ChannelType.proxy: + self.prog.check_buffer_exists(self.rank, "scratch") + temp_chunk = self.prog.get_ref(self.rank, "scratch", 0, self.size) assert self.rank != dst, "Cannot put to the same rank" buffer, index = self._get_buffer_index(dst, buffer, index) @@ -197,17 +201,17 @@ def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, dst_chunkref = 
self.prog.get_ref(dst, buffer, index, self.size) self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) if use_packet: - self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet) + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet, temp_chunk) self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none) self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none) else: - self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type) + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, temp_chunk) return dst_chunkref def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): return self._put(dst, buffer, index, sendtb, chan_type) - def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm, temp_chunk=None): + def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True) def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index 577166e..d766da4 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -261,6 +261,9 @@ def remove_empty_fields(d): elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet or op.inst == Instruction.transform_to_packet: src = op.src dst = op.dst + elif op.inst == Instruction.transform_to_packet: + src = op.src + dst = op.dst if op.inst != Instruction.nop: instr = { "name": op.inst.value,