From 6cf1163d079be55af7ac85622b48ce0769fd4950 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 13 Sep 2024 02:23:17 -0700 Subject: [PATCH] Add interleaved replication policy (#14) Add interleaved replication policy --- msccl/language/mscclpp/instruction_dag.py | 79 ++++++++++++----------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index 82555b1..e5e45ac 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -672,6 +672,8 @@ def get_new_index(rank, buffer, index, size, i): if is_scratch(buffer): buf_instance_len = self.buffers[rank][buffer].instance_size() return buf_instance_len * i + index + elif replication_policy == ReplicationPolicy.interleaved: + return index * instances + i * size return len(self.buffers[rank][buffer]) * i + index def get_instance_ref(ref): @@ -679,44 +681,43 @@ def get_instance_ref(ref): iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) return iref - if replication_policy == ReplicationPolicy.duplicated: - for i in range(instances): - # Generate all the threadblocks and ops - for rank, rank_tbs in enumerate(self.tbs): - # rank_channels = self.num_channels[rank] - for tbid, tb in rank_tbs.items(): - itbid = tbid * instances + i - itb = Threadblock(id=itbid) - itb.ops = [None] * len(tb.ops) - for s, op in enumerate(tb.ops): - isrc = get_instance_ref(op.src) - idst = get_instance_ref(op.dst) - idepends = [] - # Note: We don't need the fill out the rest of the metadata since replication is the last optimization - iop = Op( - op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type - ) - itb.ops[s] = iop - for src, step in op.srcs: - isrc = get_instance_ref(src) - iop.srcs.append((isrc, step)) - for dst, step in op.dsts: - idst = get_instance_ref(dst) - iop.dsts.append((idst, step)) - for chan in tb.channels: - itb.channels.append(chan) - self.instanced_tbs[op.rank][itbid] = itb - - # Redo dependency analysis + for i in range(instances): + # Generate all the threadblocks and ops for rank, rank_tbs in enumerate(self.tbs): + # rank_channels = self.num_channels[rank] for tbid, tb in rank_tbs.items(): - for i in range(instances): - itbid = tbid * instances + i - itb = self.instanced_tbs[rank][itbid] - for op, iop in zip(tb.ops, itb.ops): - iop.depends = [None] * len(op.depends) - for s, dep in enumerate(op.depends): - dep_tbid = dep.tb - dep_itbid = dep_tbid * instances + i - dep_step = dep.step - iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] + itbid = tbid * instances + i + itb = Threadblock(id=itbid) + itb.ops = [None] * len(tb.ops) + for s, op in enumerate(tb.ops): + isrc = get_instance_ref(op.src) + idst = get_instance_ref(op.dst) + idepends = [] + # Note: We don't need the fill out the rest of the metadata since replication is the last optimization + iop = Op( + op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type + ) + itb.ops[s] = iop + for src, step in op.srcs: + isrc = get_instance_ref(src) + iop.srcs.append((isrc, step)) + for dst, step in op.dsts: + idst = get_instance_ref(dst) + iop.dsts.append((idst, step)) + for chan in tb.channels: + itb.channels.append(chan) + self.instanced_tbs[op.rank][itbid] = itb + + # Redo dependency analysis + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + for i in range(instances): + itbid = tbid * instances + i + itb = self.instanced_tbs[rank][itbid] + for op, iop in zip(tb.ops, itb.ops): + iop.depends = [None] * len(op.depends) + for s, dep in enumerate(op.depends): + dep_tbid = dep.tb + dep_itbid = dep_tbid * instances + i + dep_step = dep.step + iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step]