From b314728a1b62afda14059caf69959e8edc9f94f2 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 14 Aug 2024 11:07:50 +0000 Subject: [PATCH] Fix perf issue --- msccl/language/instruction_dag.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py index 451ea7e..5c859d5 100755 --- a/msccl/language/instruction_dag.py +++ b/msccl/language/instruction_dag.py @@ -144,10 +144,16 @@ def _read(self, rank, buffer, index, size, op): op.prev.add(prev_op) def _infer_dependencies(self): + visited = set() for slot, ops in self.operations.items(): + if ops in visited: + continue frontier = [ops] while len(frontier) > 0: op = frontier[0] + if op in visited: + frontier = frontier[1:] + continue # Dependencies for every op is the same as the ops that are stored in prev # Filter out dependencies that are satisified by tbs executing ops sequentially # If multiple dependent ops from the same tb keep the one that happens last @@ -158,6 +164,7 @@ def _infer_dependencies(self): if tb not in depends or dep_op.step > depends[tb].step: depends[tb] = dep_op op.depends = list(depends.values()) + visited.add(op) frontier = frontier[1:] + op.next # Convert local scratch buffers to index into one global scratch buffer @@ -307,7 +314,11 @@ def optimize(self): # Completes metadata for chunk_steps (number of steps from a start op) and priority (number of steps to the last op) def _complete_metadata(self): + visited = set() def dfs(op, cs): + # already visited and no need to update chunk_step + if op.chunk_step >= cs + 1 and op in visited: + return op.chunk_step = max(op.chunk_step, cs + 1) if len(op.next) == 0 and op.recv_match is None: @@ -316,12 +327,14 @@ def dfs(op, cs): for o in op.next: dfs(o, op.chunk_step) # Priority = +1 of the highest priority child - if len(op.next) > 0: + if len(op.next) > 0 and op not in visited: highest_next_priority = max([x.priority + 1 for x in
op.next]) op.priority = max(highest_next_priority, op.priority) if op.is_send(): dfs(op.recv_match, op.chunk_step) - op.priority = max(op.priority, op.recv_match.priority + 1) + if op not in visited: + op.priority = max(op.priority, op.recv_match.priority + 1) + visited.add(op) for chunk, op in self.operations.items(): if op.inst == Instruction.start: @@ -334,10 +347,16 @@ def dfs(op, cs): # recv-copy-send # recv(src, sbuf, si, _, _, _ ) send(_, _, _, dst, dbuf, di) -> recv_copy_send(src, sbuf, si, dst, dbuf, di) def _optimize_rcs(self): + visited = set() for slot, ops in self.operations.items(): + if ops in visited: + continue frontier = [ops] while len(frontier) > 0: op = frontier[0] + if op in visited: + frontier = frontier[1:] + continue for next_op in op.next: if ( op.inst == Instruction.recv @@ -353,6 +372,7 @@ def _optimize_rcs(self): op.recv_match = next_op.recv_match remove_op(next_op) break + visited.add(op) frontier = frontier[1:] + op.next # recv-reduce-send - A rrc followed by a send that gets overwritten @@ -361,10 +381,16 @@ def _optimize_rcs(self): # rrc(src, sbuf, si, ...) send(_, _, _, dst, dbuf, di) def _optimize_rrcs_rrs(self): # RRC/S -> RRS + visited = set() for slot, ops in self.operations.items(): + if ops in visited: + continue frontier = [ops] while len(frontier) > 0: op = frontier[0] + if op in visited: + frontier = frontier[1:] + continue if len(op.next) == 1: next_op = op.next[0] if len(next_op.next) == 1: @@ -395,6 +421,7 @@ def _optimize_rrcs_rrs(self): next_op.recv_match.send_match = op op.recv_match = next_op.recv_match remove_op(next_op) + visited.add(op) frontier = frontier[1:] + op.next # Automatically replicates the algorithm instance number of times