Skip to content

Commit

Permalink
Fix perf issue (#8)
Browse files Browse the repository at this point in the history
Reduce DAG travel time complexity from O(n^2) to O(n).
Since _complete_metadata needs to calculate depth, the time complexity is still O(n^2). Just removed unnecessary branches
  • Loading branch information
Binyang2014 authored Aug 15, 2024
1 parent b9b5c3e commit 52a226b
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions msccl/language/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,16 @@ def _read(self, rank, buffer, index, size, op):
op.prev.add(prev_op)

def _infer_dependencies(self):
for slot, ops in self.operations.items():
frontier = [ops]
visited = set()
for _, op in self.operations.items():
if op in visited:
continue
frontier = [op]
while len(frontier) > 0:
op = frontier[0]
if op in visited:
frontier = frontier[1:]
continue
# Dependencies for every op is the same as the ops that are stored in prev
# Filter out dependencies that are satisified by tbs executing ops sequentially
# If multiple dependent ops from the same tb keep the one that happens last
Expand All @@ -158,6 +164,7 @@ def _infer_dependencies(self):
if tb not in depends or dep_op.step > depends[tb].step:
depends[tb] = dep_op
op.depends = list(depends.values())
visited.add(op)
frontier = frontier[1:] + op.next

# Convert local scratch buffers to index into one global scratch buffer
Expand Down Expand Up @@ -307,7 +314,12 @@ def optimize(self):

# Completes metadata for chunk_steps (number of steps from a start op) and priority (number of steps to the last op)
def _complete_metadata(self):
visited = set()

def dfs(op, cs):
# already visited and no need to update chunk_step
if op.chunk_step >= cs + 1 and op in visited:
return
op.chunk_step = max(op.chunk_step, cs + 1)

if len(op.next) == 0 and op.recv_match is None:
Expand All @@ -316,12 +328,14 @@ def dfs(op, cs):
for o in op.next:
dfs(o, op.chunk_step)
# Priority = +1 of the highest priority child
if len(op.next) > 0:
if len(op.next) > 0 and op not in visited:
highest_next_priority = max([x.priority + 1 for x in op.next])
op.priority = max(highest_next_priority, op.priority)
if op.is_send():
dfs(op.recv_match, op.chunk_step)
op.priority = max(op.priority, op.recv_match.priority + 1)
if op not in visited:
op.priority = max(op.priority, op.recv_match.priority + 1)
visited.add(op)

for chunk, op in self.operations.items():
if op.inst == Instruction.start:
Expand All @@ -334,10 +348,16 @@ def dfs(op, cs):
# recv-copy-send
# recv(src, sbuf, si, _, _, _ ) send(_, _, _, dst, dbuf, di) -> recv_copy_send(src, sbuf, si, dst, dbuf, di)
def _optimize_rcs(self):
for slot, ops in self.operations.items():
frontier = [ops]
visited = set()
for _, op in self.operations.items():
if op in visited:
continue
frontier = [op]
while len(frontier) > 0:
op = frontier[0]
if op in visited:
frontier = frontier[1:]
continue
for next_op in op.next:
if (
op.inst == Instruction.recv
Expand All @@ -353,6 +373,7 @@ def _optimize_rcs(self):
op.recv_match = next_op.recv_match
remove_op(next_op)
break
visited.add(op)
frontier = frontier[1:] + op.next

# recv-reduce-send - A rrc followed by a send that gets overwritten
Expand All @@ -361,10 +382,16 @@ def _optimize_rcs(self):
# rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di)
def _optimize_rrcs_rrs(self):
# RRC/S -> RRS
for slot, ops in self.operations.items():
frontier = [ops]
visited = set()
for _, op in self.operations.items():
if op in visited:
continue
frontier = [op]
while len(frontier) > 0:
op = frontier[0]
if op in visited:
frontier = frontier[1:]
continue
if len(op.next) == 1:
next_op = op.next[0]
if len(next_op.next) == 1:
Expand Down Expand Up @@ -395,6 +422,7 @@ def _optimize_rrcs_rrs(self):
next_op.recv_match.send_match = op
op.recv_match = next_op.recv_match
remove_op(next_op)
visited.add(op)
frontier = frontier[1:] + op.next

# Automatically replicates the algorithm instance number of times
Expand Down

0 comments on commit 52a226b

Please sign in to comment.