From b494b753cb719ca164928471969fc74864cb8704 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 19 Apr 2024 15:06:42 +0000 Subject: [PATCH] OPT --- .../allreduce_a100_allpairs_sm_mscclpp_get.py | 6 ++-- msccl/language/ir_mscclpp.py | 6 ++-- msccl/language/rank_dag.py | 30 +++++++++++++++++-- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py index efca998..49f3606 100644 --- a/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py +++ b/examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py @@ -35,7 +35,8 @@ def allreduce_allpairs(gpus, instances, protocol): if rank != nghr: c.wait(nghr, Buffer.input, index + tb, recvtb=tb) # reduce the chunks - for nghr in range(size): + for i in range(size): + nghr = (rank + i) % size if rank != nghr: c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) for nghr in range(size): @@ -50,7 +51,8 @@ def allreduce_allpairs(gpus, instances, protocol): index = nghr * size c = chunk(rank, Buffer.input, index + tb) c.wait(nghr, Buffer.input, index + tb, recvtb=tb) - for nghr in range(size): + for i in range(size): + nghr = (rank + i) % size index = nghr * size if rank != nghr: c = chunk(rank, Buffer.input, index + tb) diff --git a/msccl/language/ir_mscclpp.py b/msccl/language/ir_mscclpp.py index 7a57a51..1677399 100644 --- a/msccl/language/ir_mscclpp.py +++ b/msccl/language/ir_mscclpp.py @@ -181,6 +181,7 @@ def remove_empty_fields(d): dst_channel_ids = [] src_channel_ids = [] srcs = [] + dsts = [] src = None dst = None if op.tb == -1: @@ -239,10 +240,10 @@ def remove_empty_fields(d): src = op.src elif op.inst == Instruction.get: src_channel_ids = get_channel_ids( - [op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type + op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type ) i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} - dst = op.dst + dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts)) elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet: src = op.src dst = op.dst @@ -255,6 +256,7 @@ def remove_empty_fields(d): "o_cids": dst_channel_ids, "src": src.rank if src else None, "srcs": srcs if srcs else None, + "dsts": dsts if dsts else None, "srcbuff": src.buffer.value if src and src.buffer else None, "srcoff": src.index if src else None, "dst": dst.rank if dst else None, diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index 07daada..b66475a 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -230,6 +230,8 @@ def add_get(self, rank, send_ref, recv_ref, tb, ch_type): index = recv_ref.index size = recv_ref.size self._write(rank, buffer, index, size, op) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) return op # InstructionDAG - adds a signal node. @@ -328,7 +330,7 @@ def complete_channels(self): chans.add(chan) tb.channels = list(chans) - def _optimize_redandant_signal_wait(self): + def _optimize_redundant_signal_wait(self): # For packet ops, we can remove signal/wait for rank, rank_tbs in enumerate(self.tbs): for tbid, tb in rank_tbs.items(): @@ -489,6 +491,29 @@ def _optimize_rrcs_rs(self): continue queue = queue[1:] + # get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di]) + # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di]) + def _optimize_get_put(self): + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + if op.inst == Instruction.get: + fused = False + if len(queue) > 1: + seq_op = queue[1] + if seq_op.inst == Instruction.get and same_src_dst_buffer_type(op, seq_op) and same_chan_type(op, seq_op) and same_count(op, seq_op): + op.dsts.append((ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), seq_op.step)) + op.srcs.append((ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), seq_op.step)) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + fused = True + if fused: + continue + queue = queue[1:] + # For signal/wait ops, if they are independent of other operations and no other operations in between, # then merge them into a single signal/wait op # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) @@ -529,9 +554,10 @@ def _parallel_signal_wait(self): queue = queue[1:] def optimize_mscclpp(self): - self._optimize_redandant_signal_wait() + self._optimize_redundant_signal_wait() self._optimize_rrc_r_signal_wait() self._optimize_rrcs_rs() + self._optimize_get_put() self._parallel_signal_wait()