From 71bf147c7f5df3ff51f43988c6321de4f96cf7fe Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 30 Aug 2024 08:26:18 +0000 Subject: [PATCH] WIP --- .../mscclpp/send_recv_a100_packet.mscclpp.py | 11 ++++++-- msccl/language/mscclpp/__init__.py | 27 ++++++++++++------- msccl/language/mscclpp/instruction_dag.py | 16 +++++------ msccl/language/mscclpp/ir.py | 4 +-- msccl/language/types.py | 2 +- 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/examples/mscclang/mscclpp/send_recv_a100_packet.mscclpp.py b/examples/mscclang/mscclpp/send_recv_a100_packet.mscclpp.py index 6752d22..1444eeb 100644 --- a/examples/mscclang/mscclpp/send_recv_a100_packet.mscclpp.py +++ b/examples/mscclang/mscclpp/send_recv_a100_packet.mscclpp.py @@ -24,8 +24,15 @@ def send_recv(instances): if nghr == r: continue c = chunk(r, Buffer.input, 0) - c = c.trans_to_packet(r, "scratch", 0, sendtb=0) - c.put_packet(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy, trans_to_packet=False) + c.put_packet( + nghr, + "scratch", + 1, + sendtb=0, + chan_type=ChannelType.proxy, + temp_buffer="scratch", + temp_buffer_index=0, + ) for r in range(size): c = chunk(r, "scratch", 1) diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py index 903329f..502f6b3 100644 --- a/msccl/language/mscclpp/__init__.py +++ b/msccl/language/mscclpp/__init__.py @@ -186,9 +186,7 @@ def _get_buffer_index(self, remote_rank, buffer, index): return buffer, self.prog.buffers[remote_rank][buffer].instance_size() return buffer, index - def _put( - self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False, trans_to_packet=False - ): + def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False): self.prog.check_buffer_exists(dst, buffer) assert self.rank != dst, "Cannot put to the same rank" buffer, index = self._get_buffer_index(dst, buffer, index) @@ -198,7 +196,7 @@ def _put( dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) if use_packet: - self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, trans_to_packet) + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True) self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none) self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none) else: @@ -208,8 +206,22 @@ def _put( def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): return self._put(dst, buffer, index, sendtb, chan_type) - def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, trans_to_packet=True): - return self._put(dst, buffer, index, sendtb, chan_type, use_packet=True, trans_to_packet=trans_to_packet) + def put_packet( + self, + dst, + buffer=None, + index=-1, + sendtb=-1, + chan_type=ChannelType.sm, + temp_buffer=None, + temp_buffer_index=-1, + ): + chunk_ref = self + if chan_type == ChannelType.proxy: + chunk_ref = self._copy( + self.rank, temp_buffer, temp_buffer_index, sendtb, trans_from_packet=False, trans_to_packet=True + ) + return chunk_ref._put(dst, buffer, index, sendtb, chan_type, True) def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): self.prog.check_buffer_exists(src, buffer) @@ -269,9 +281,6 @@ def _copy(self, dst, buffer=None, index=-1, sendtb=-1, trans_from_packet=False, def copy(self, dst, buffer=None, index=-1, sendtb=-1): return self._copy(dst, buffer, index, sendtb) - def trans_to_packet(self, dst, buffer=None, index=-1, sendtb=-1): - return self._copy(dst, buffer, index, sendtb, trans_from_packet=False, trans_to_packet=True) - def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False) diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index f18da46..82555b1 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -63,11 +63,11 @@ def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False): return op # InstructionDAG - adds a put node - def add_put(self, rank, send_ref, recv_ref, tb, ch_type, trans_to_packet=False): + def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False): tb_step = self._get_tb_step(rank, tb) - if trans_to_packet: + if use_packet: op = Op( - Instruction.trans_put_packet, + Instruction.put_packet, rank, send_ref, recv_ref, @@ -167,7 +167,7 @@ def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type): return op def complete_channels(self): - send_op = [Instruction.put, Instruction.signal, Instruction.trans_put_packet] + send_op = [Instruction.put, Instruction.signal, Instruction.put_packet] recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): @@ -198,7 +198,7 @@ def _optimize_redundant_signal_wait(self): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.trans_put_packet: + if op.inst == Instruction.put_packet: fused = False for next_op in op.next: if next_op.inst == Instruction.signal: @@ -439,7 +439,7 @@ def _optimize_rrcs_rs(self): fused = False for next_op in op.next: if ( - next_op.inst == Instruction.trans_put_packet + next_op.inst == Instruction.put_packet and same_count(op, next_op) and buf_dst_src_match(op, next_op) and next_op.channel_type == ChannelType.sm @@ -533,12 +533,12 @@ def _optimize_get_put(self): fused = True if fused: continue - elif op.inst == Instruction.trans_put_packet: + elif op.inst == Instruction.put_packet: fused = False if len(queue) > 1: seq_op = queue[1] if ( - seq_op.inst == Instruction.trans_put_packet + seq_op.inst == Instruction.put_packet and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op) diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index d9ebb49..c7a28f4 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -8,7 +8,7 @@ _local_src_insts_mscclpp: set = { Instruction.put, - Instruction.trans_put_packet, + Instruction.put_packet, Instruction.signal, Instruction.copy, Instruction.copy_packet, @@ -254,7 +254,7 @@ def remove_empty_fields(d): "name": op.inst.value, "deps": list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)), } - elif op.inst == Instruction.put or op.inst == Instruction.trans_put_packet: + elif op.inst == Instruction.put or op.inst == Instruction.put_packet: dst_channel_ids = get_channel_ids( op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type ) diff --git a/msccl/language/types.py b/msccl/language/types.py index 768579f..be5677d 100644 --- a/msccl/language/types.py +++ b/msccl/language/types.py @@ -115,7 +115,7 @@ class MscclppInstruction(Enum): reduce_send_packet = "rspkt" reduce_packet = "rpkt" put = "put" - trans_put_packet = "ppkt" + put_packet = "ppkt" get = "get" wait = "wait" signal = "signal"