From 3e3104ee63024131ec058cd60d9c2620b6fb3192 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 11 Sep 2024 08:26:54 +0000 Subject: [PATCH] WIP --- msccl/language/mscclpp/__init__.py | 11 ++ msccl/language/mscclpp/instruction_dag.py | 194 ++++++---------------- msccl/language/mscclpp/ir.py | 1 + msccl/language/mscclpp/optimizer.py | 58 +++++++ 4 files changed, 124 insertions(+), 140 deletions(-) create mode 100644 msccl/language/mscclpp/optimizer.py diff --git a/msccl/language/mscclpp/__init__.py b/msccl/language/mscclpp/__init__.py index 8ba171e..ceedb28 100644 --- a/msccl/language/mscclpp/__init__.py +++ b/msccl/language/mscclpp/__init__.py @@ -252,6 +252,17 @@ def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type) + # only proxy channel need to use this function + def flush(self, dst, buffer=None, index=-1, sendtb=-1): + sender = self.rank + receiver = dst + assert sender != receiver, "Cannot flush to the same rank" + buffer, index = self._get_buffer_index(dst, buffer, index) + + assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb) + def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): sender = src receiver = self.rank diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index 82555b1..a7c604c 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -3,7 +3,6 @@ from msccl.language.buffer import Buffer -from msccl.language.types import Channel, ChannelType from msccl.language.instruction_dag import ( buf_dst_src_match, merge_op, @@ -16,7 +15,16 @@ same_src_dst_buffer_type, ) from msccl.language.instruction_dag import InstructionDAG -from msccl.language.types import ChunkRef, MscclppInstruction as Instruction, Op, ReplicationPolicy, Threadblock +from msccl.language.mscclpp.optimizer import Optimizer +from msccl.language.types import ( + Channel, + ChannelType, + ChunkRef, + MscclppInstruction as Instruction, + Op, + ReplicationPolicy, + Threadblock, +) class MscclppInstructionDAG(InstructionDAG): @@ -133,6 +141,27 @@ def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) return op + def add_flush(self, rank, send_ref, recv_ref, tb): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.flush, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ChannelType.proxy, + step=tb_step, + ) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + return op + def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): tb_step = self._get_tb_step(rank, tb) op = Op( @@ -220,152 +249,37 @@ def _optimize_redundant_signal_wait(self): # rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di) # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di]) + # flush(_,_,_,dst,dbuf,di) flush(_,_,_,dst,dbuf,di) -> flush(_,_,_,list[dst,dbuf,di]) # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di) # reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di) - def _optimize_rrc_r_signal_wait(self): - for rank, rank_tbs in enumerate(self.tbs): - for tbid, tb in rank_tbs.items(): + def _optimize_fuse_same_instruction(self): + optimizer = Optimizer() + # Mapping instruction to their respective condition checks and same buffer function + instruction_handlers = { + Instruction.read_reduce_copy: same_buf_dst, + Instruction.reduce: same_buf_dst, + Instruction.reduce_packet: same_buf_dst, + Instruction.signal: same_buf_src, + Instruction.wait: same_buf_dst, + Instruction.flush: same_buf_src, + } + + for _, rank_tbs in enumerate(self.tbs): + for _, tb in rank_tbs.items(): queue = list(tb.ops) while len(queue) > 0: op = queue[0] - if op.inst == Instruction.read_reduce_copy: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.read_reduce_copy - and same_count(op, next_op) - and same_buf_dst(op, next_op) - and same_chan_type(op, next_op) - and not circular_dep_after_merge(op, next_op) - ): - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - merge_op(op, next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.reduce: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.reduce - and same_buf_dst(op, next_op) - and same_chan_type(op, next_op) - and not circular_dep_after_merge(op, next_op) - ): - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - merge_op(op, next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.reduce_packet: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.reduce_packet - and same_buf_dst(op, next_op) - and same_chan_type(op, next_op) - and not circular_dep_after_merge(op, next_op) - ): - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - merge_op(op, next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.signal: - fused = False + fused = False + inst_type = op.inst + if inst_type in instruction_handlers: for next_op in op.next: - if ( - next_op.inst == Instruction.signal - and same_buf_src(op, next_op) - and same_chan_type(op, next_op) - and not circular_dep_after_merge(op, next_op) - ): - op.dsts.append( - ( - ChunkRef( - next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size - ), - next_op.step, - ) - ) - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - merge_op(op, next_op) - tb.ops.remove(next_op) - queue.remove(next_op) - fused = True - break - if fused: - continue - elif op.inst == Instruction.wait: - fused = False - for next_op in op.next: - if ( - next_op.inst == Instruction.wait - and same_buf_dst(op, next_op) - and same_chan_type(op, next_op) - and not circular_dep_after_merge(op, next_op) - ): - op.srcs.append( - ( - ChunkRef( - next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size - ), - next_op.step, - ) - ) - op.dsts.append( - ( - ChunkRef( - next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size - ), - next_op.step, - ) - ) - merge_op(op, next_op) - tb.ops.remove(next_op) - queue.remove(next_op) + same_buf_func = instruction_handlers[inst_type] + if optimizer.try_merge_same_instruction(op, next_op, tb, queue, inst_type, same_buf_func): fused = True break - if fused: - continue + if fused: + continue queue = queue[1:] # rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_) @@ -643,7 +557,7 @@ def _get_tb_step(self, rank: int, tb: int): def optimize(self): self._optimize_redundant_signal_wait() - self._optimize_rrc_r_signal_wait() + self._optimize_fuse_same_instruction() self._optimize_rrcs_rs() self._optimize_get_put() diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index baee9e6..5aef8cc 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -10,6 +10,7 @@ Instruction.put, Instruction.put_packet, Instruction.signal, + Instruction.flush, Instruction.copy, Instruction.copy_packet, Instruction.transform_to_packet, diff --git a/msccl/language/mscclpp/optimizer.py b/msccl/language/mscclpp/optimizer.py new file mode 100644 index 0000000..dec1ac7 --- /dev/null +++ b/msccl/language/mscclpp/optimizer.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from msccl.language.instruction_dag import ( + merge_op, + circular_dep_after_merge, + same_chan_type, + same_count, +) +from msccl.language.types import ChunkRef, MscclppInstruction as Instruction, Op, ReplicationPolicy, Threadblock + + +class Optimizer: + def try_merge_same_instruction( + self, op: Op, next_op: Op, tb: Threadblock, queue: list, inst_type: Instruction, same_buf_func: callable + ) -> bool: + """ + Attempts to merge two instruction if conditions are met. + :param op: The current operation. + :param next_op: The next operation to potentially merge with. + :param tb: The thread block containing the operations. + :param queue: The queue of operations. + :param inst_type: The type of the instruction being processed. + :param same_buf_func: The function to check if the buffer is the same (same_buf_dst or same_buf_src). + :return: True if operations are merged, False otherwise. + """ + if ( + next_op.inst == inst_type + and same_buf_func(op, next_op) + and same_count(op, next_op) + and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) + ): + # Append the source chunks from next_op + op.srcs.append( + ( + ChunkRef( + next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size + ), + next_op.step, + ) + ) + # For 'signal' and 'wait' instructions, append destination chunks too + if inst_type in [Instruction.signal, Instruction.wait, Instruction.flush]: + op.dsts.append( + ( + ChunkRef( + next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size + ), + next_op.step, + ) + ) + # Merge operations + merge_op(op, next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + return True + return False