diff --git a/msccl/language/instruction_dag.py b/msccl/language/instruction_dag.py index 5c859d5..60f1867 100755 --- a/msccl/language/instruction_dag.py +++ b/msccl/language/instruction_dag.py @@ -145,10 +145,10 @@ def _read(self, rank, buffer, index, size, op): def _infer_dependencies(self): visited = set() - for slot, ops in self.operations.items(): - if ops in visited: + for _, op in self.operations.items(): + if op in visited: continue - frontier = [ops] + frontier = [op] while len(frontier) > 0: op = frontier[0] if op in visited: @@ -315,6 +315,7 @@ def optimize(self): # Completes metadata for chunk_steps (number of steps from a start op) and priority (number of steps to the last op) def _complete_metadata(self): visited = set() + def dfs(op, cs): # already visited and no need to update chunk_step if op.chunk_step >= cs + 1 and op in visited: @@ -348,10 +349,10 @@ def dfs(op, cs): # recv(src, sbuf, si, _, _, _ ) send(_, _, _, dst, dbuf, di) -> recv_copy_send(src, sbuf, si, dst, dbuf, di) def _optimize_rcs(self): visited = set() - for slot, ops in self.operations.items(): - if ops in visited: + for _, op in self.operations.items(): + if op in visited: continue - frontier = [ops] + frontier = [op] while len(frontier) > 0: op = frontier[0] if op in visited: @@ -382,10 +383,10 @@ def _optimize_rcs(self): def _optimize_rrcs_rrs(self): # RRC/S -> RRS visited = set() - for slot, ops in self.operations.items(): - if ops in visited: + for _, op in self.operations.items(): + if op in visited: continue - frontier = [ops] + frontier = [op] while len(frontier) > 0: op = frontier[0] if op in visited: