Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Sep 23, 2024
1 parent 399df8a commit 5528f3b
Showing 1 changed file with 8 additions and 25 deletions.
33 changes: 8 additions & 25 deletions msccl/language/mscclpp/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,19 +326,24 @@ def _optimize_rrcs_rs(self):
continue
queue = queue[1:]

# merge ops which are independent of other operations and no other operations in between
# get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di])
# put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di])
# putWithSignal/putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
# putWithSignal/putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
# -> putWithSignal/putWithSignalAndFlush(list[src,sbuf,si], list[dst,dbuf,di])
def _optimize_get_put(self):
# wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_])
def _compact_instructions(self):
optimizer = InstructionOptimizer()
campactable_inst = [
Instruction.get,
Instruction.put,
Instruction.put_packet,
Instruction.put_with_signal,
Instruction.put_with_signal_and_flush,
Instruction.signal,
Instruction.flush,
Instruction.wait,
]
for _, rank_tbs in enumerate(self.tbs):
for _, tb in rank_tbs.items():
Expand All @@ -347,28 +352,8 @@ def _optimize_get_put(self):
op = queue[0]
fused = False
if op.inst in campactable_inst:
if len(queue) > 1:
fused = optimizer.try_compact_instructions(op, tb, queue, op.inst, same_src_dst_buffer_type)

if fused:
continue
queue = queue[1:]

# For signal/wait ops, if they are independent of other operations and no other operations in between,
# then merge them into a single signal/wait/flush op
# wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_])
def _compact_signal_flush_wait(self):
optimizer = InstructionOptimizer()
for rank, rank_tbs in enumerate(self.tbs):
for tbid, tb in rank_tbs.items():
if tbid == -1:
continue
queue = list(tb.ops)
while len(queue) > 0:
op = queue[0]
fused = False
if op.inst == Instruction.signal or op.inst == Instruction.wait or op.inst == Instruction.flush:
fused = optimizer.try_compact_instructions(op, tb, queue, op.inst, same_src_dst_buffer_type)

if fused:
continue
queue = queue[1:]
Expand All @@ -386,9 +371,7 @@ def optimize(self):
self._remove_redundant_signal_wait()
self._fuse_same_instructions()
self._optimize_rrcs_rs()
self._optimize_get_put()

self._compact_signal_flush_wait()
self._compact_instructions()

def replicate(self, instances: int, replication_policy: ReplicationPolicy):
# update op step
Expand Down

0 comments on commit 5528f3b

Please sign in to comment.