Skip to content

Commit

Permalink
OPT
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Apr 19, 2024
1 parent 42c4a7d commit b494b75
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
6 changes: 4 additions & 2 deletions examples/mscclang/allreduce_a100_allpairs_sm_mscclpp_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def allreduce_allpairs(gpus, instances, protocol):
if rank != nghr:
c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
# reduce the chunks
for nghr in range(size):
for i in range(size):
nghr = (rank + i) % size
if rank != nghr:
c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb)
for nghr in range(size):
Expand All @@ -50,7 +51,8 @@ def allreduce_allpairs(gpus, instances, protocol):
index = nghr * size
c = chunk(rank, Buffer.input, index + tb)
c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
for nghr in range(size):
for i in range(size):
nghr = (rank + i) % size
index = nghr * size
if rank != nghr:
c = chunk(rank, Buffer.input, index + tb)
Expand Down
6 changes: 4 additions & 2 deletions msccl/language/ir_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def remove_empty_fields(d):
dst_channel_ids = []
src_channel_ids = []
srcs = []
dsts = []
src = None
dst = None
if op.tb == -1:
Expand Down Expand Up @@ -239,10 +240,10 @@ def remove_empty_fields(d):
src = op.src
elif op.inst == Instruction.get:
src_channel_ids = get_channel_ids(
[op.src], tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
)
i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
dst = op.dst
dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts))
elif op.inst == Instruction.copy or op.inst == Instruction.copy_packet:
src = op.src
dst = op.dst
Expand All @@ -255,6 +256,7 @@ def remove_empty_fields(d):
"o_cids": dst_channel_ids,
"src": src.rank if src else None,
"srcs": srcs if srcs else None,
"dsts": dsts if dsts else None,
"srcbuff": src.buffer.value if src and src.buffer else None,
"srcoff": src.index if src else None,
"dst": dst.rank if dst else None,
Expand Down
30 changes: 28 additions & 2 deletions msccl/language/rank_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def add_get(self, rank, send_ref, recv_ref, tb, ch_type):
index = recv_ref.index
size = recv_ref.size
self._write(rank, buffer, index, size, op)
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
return op

# InstructionDAG - adds a signal node.
Expand Down Expand Up @@ -328,7 +330,7 @@ def complete_channels(self):
chans.add(chan)
tb.channels = list(chans)

def _optimize_redandant_signal_wait(self):
def _optimize_redundant_signal_wait(self):
# For packet ops, we can remove signal/wait
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
Expand Down Expand Up @@ -489,6 +491,29 @@ def _optimize_rrcs_rs(self):
continue
queue = queue[1:]

# Fusion rules for back-to-back ops of the same kind:
# get(src, sbuf, si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di])
# put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di])
# NOTE(review): only the get rule is implemented in the body below; ops whose
# instruction is not Instruction.get are skipped without modification.
def _optimize_get_put(self):
    """Fuse adjacent, compatible ``get`` ops inside every thread block.

    For each thread block in ``self.tbs`` a working copy of ``tb.ops`` is
    scanned front to back. When the op at the front and the op right after
    it are both ``Instruction.get`` and ``same_src_dst_buffer_type``,
    ``same_chan_type`` and ``same_count`` all accept the pair, the second
    op's dst and src chunks (each paired with the second op's ``step``) are
    appended to the first op's ``dsts`` / ``srcs`` lists, ``merge_op`` is
    called on the pair, and the second op is removed from ``tb.ops``.

    After a fusion the front op is kept in place and compared against its
    new neighbour, so a run of compatible gets collapses into one op.
    """
    for rank_tbs in self.tbs:
        for _tbid, tb in rank_tbs.items():
            # Work on a copy so that tb.ops can be mutated while scanning.
            pending = list(tb.ops)
            while pending:
                head = pending[0]
                if head.inst == Instruction.get and len(pending) > 1:
                    follower = pending[1]
                    can_fuse = (
                        follower.inst == Instruction.get
                        and same_src_dst_buffer_type(head, follower)
                        and same_chan_type(head, follower)
                        and same_count(head, follower)
                    )
                    if can_fuse:
                        follower_dst = follower.dst
                        head.dsts.append(
                            (
                                ChunkRef(
                                    follower_dst.rank,
                                    follower_dst.buffer,
                                    follower_dst.index,
                                    follower_dst.size,
                                ),
                                follower.step,
                            )
                        )
                        follower_src = follower.src
                        head.srcs.append(
                            (
                                ChunkRef(
                                    follower_src.rank,
                                    follower_src.buffer,
                                    follower_src.index,
                                    follower_src.size,
                                ),
                                follower.step,
                            )
                        )
                        # presumably rewires the follower's dependencies onto
                        # head — confirm against merge_op's definition
                        merge_op(head, follower)
                        tb.ops.remove(follower)
                        pending.remove(follower)
                        # Retry the same head against its new successor.
                        continue
                # Nothing fused for this head: advance to the next op.
                pending.pop(0)

# For signal/wait ops, if they are independent of other operations and no other operations in between,
# then merge them into a single signal/wait op
# wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_])
Expand Down Expand Up @@ -529,9 +554,10 @@ def _parallel_signal_wait(self):
queue = queue[1:]

def optimize_mscclpp(self):
self._optimize_redandant_signal_wait()
self._optimize_redundant_signal_wait()
self._optimize_rrc_r_signal_wait()
self._optimize_rrcs_rs()
self._optimize_get_put()

self._parallel_signal_wait()

Expand Down

0 comments on commit b494b75

Please sign in to comment.