Skip to content

Commit

Permalink
Merge branch 'main' into binyli/num_threads
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 authored Sep 13, 2024
2 parents cfd02f7 + 6cf1163 commit a242ec8
Showing 1 changed file with 40 additions and 39 deletions.
79 changes: 40 additions & 39 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,51 +672,52 @@ def get_new_index(rank, buffer, index, size, i):
if is_scratch(buffer):
buf_instance_len = self.buffers[rank][buffer].instance_size()
return buf_instance_len * i + index
elif replication_policy == ReplicationPolicy.interleaved:
return index * instances + i * size
return len(self.buffers[rank][buffer]) * i + index

def get_instance_ref(ref):
iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i)
iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size)
return iref

if replication_policy == ReplicationPolicy.duplicated:
for i in range(instances):
# Generate all the threadblocks and ops
for rank, rank_tbs in enumerate(self.tbs):
# rank_channels = self.num_channels[rank]
for tbid, tb in rank_tbs.items():
itbid = tbid * instances + i
itb = Threadblock(id=itbid)
itb.ops = [None] * len(tb.ops)
for s, op in enumerate(tb.ops):
isrc = get_instance_ref(op.src)
idst = get_instance_ref(op.dst)
idepends = []
# Note: We don't need the fill out the rest of the metadata since replication is the last optimization
iop = Op(
op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type
)
itb.ops[s] = iop
for src, step in op.srcs:
isrc = get_instance_ref(src)
iop.srcs.append((isrc, step))
for dst, step in op.dsts:
idst = get_instance_ref(dst)
iop.dsts.append((idst, step))
for chan in tb.channels:
itb.channels.append(chan)
self.instanced_tbs[op.rank][itbid] = itb

# Redo dependency analysis
for i in range(instances):
# Generate all the threadblocks and ops
for rank, rank_tbs in enumerate(self.tbs):
# rank_channels = self.num_channels[rank]
for tbid, tb in rank_tbs.items():
for i in range(instances):
itbid = tbid * instances + i
itb = self.instanced_tbs[rank][itbid]
for op, iop in zip(tb.ops, itb.ops):
iop.depends = [None] * len(op.depends)
for s, dep in enumerate(op.depends):
dep_tbid = dep.tb
dep_itbid = dep_tbid * instances + i
dep_step = dep.step
iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step]
itbid = tbid * instances + i
itb = Threadblock(id=itbid)
itb.ops = [None] * len(tb.ops)
for s, op in enumerate(tb.ops):
isrc = get_instance_ref(op.src)
idst = get_instance_ref(op.dst)
idepends = []
# Note: We don't need the fill out the rest of the metadata since replication is the last optimization
iop = Op(
op.inst, op.rank, isrc, idst, idepends, op.step, itbid, channel_type=op.channel_type
)
itb.ops[s] = iop
for src, step in op.srcs:
isrc = get_instance_ref(src)
iop.srcs.append((isrc, step))
for dst, step in op.dsts:
idst = get_instance_ref(dst)
iop.dsts.append((idst, step))
for chan in tb.channels:
itb.channels.append(chan)
self.instanced_tbs[op.rank][itbid] = itb

# Redo dependency analysis
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
for i in range(instances):
itbid = tbid * instances + i
itb = self.instanced_tbs[rank][itbid]
for op, iop in zip(tb.ops, itb.ops):
iop.depends = [None] * len(op.depends)
for s, dep in enumerate(op.depends):
dep_tbid = dep.tb
dep_itbid = dep_tbid * instances + i
dep_step = dep.step
iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step]

0 comments on commit a242ec8

Please sign in to comment.