diff --git a/bluefog/torch/mpi_ops.py b/bluefog/torch/mpi_ops.py
index dc508a34..e5e106ae 100644
--- a/bluefog/torch/mpi_ops.py
+++ b/bluefog/torch/mpi_ops.py
@@ -17,6 +17,7 @@
 from contextlib import contextmanager
 from typing import List, Dict, Union, Optional
 
+import numpy as np
 import torch
 
 from bluefog.torch import mpi_lib  # C library
@@ -488,10 +489,9 @@ def _neighbor_allreduce_nonblocking(tensor, output, self_weight, src_weights,
                              "enabling dynamic topology.")
     else:
         dynamic_neighbors_enabled = True
-        dst_weighting_enabled = True
         if isinstance(dst_weights, list):
             dst_weights = {dst:1.0 for dst in dst_weights}
-            dst_weighting_enabled = False
+        dst_weighting_enabled = not np.allclose(list(dst_weights.values()), 1.0)
     if self_weight is None and src_weights is None:
         # Implying this is static graph.
         if is_topo_weighted():
@@ -515,13 +515,8 @@ def _neighbor_allreduce_nonblocking(tensor, output, self_weight, src_weights,
             raise ValueError("The key of weights should only contain the ranks that belong to "
                              " in-neighbors and self rank.")
         uniform_weights = 1.0/(len(src_weights)+1)
-        weighted_average_computation = False
-        if abs(self_weight - uniform_weights) > 1e-6:
-            weighted_average_computation = True
-        for n_weights in src_weights.values():
-            if abs(n_weights - uniform_weights) > 1e-6:
-                weighted_average_computation = True
-                break
+        weighted_average_computation = not(np.isclose(self_weight, uniform_weights) and
+                                           np.allclose(list(src_weights.values()), uniform_weights))
     else:
         raise ValueError("Arguments self_weight and neighbor_weights have to be presented at "
                          "the same time")
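
Note: both rewritten checks in this patch reduce to the same question, "are the weights (close to) uniform?". Below is a minimal standalone sketch of that numpy-based comparison, which replaces the manual 1e-6 tolerance loop; the helper name is_uniform and the example weights are made up for illustration and are not part of the patch.

import numpy as np

def is_uniform(self_weight, src_weights):
    # Hypothetical helper mirroring the patched check: uniform means every
    # weight equals 1 / (number of in-neighbors + 1), up to numpy's default
    # tolerances (rtol=1e-5, atol=1e-8) rather than the old fixed 1e-6 cutoff.
    uniform = 1.0 / (len(src_weights) + 1)
    return bool(np.isclose(self_weight, uniform) and
                np.allclose(list(src_weights.values()), uniform))

# Three in-neighbors, all weights 0.25 -> uniform, so the backend can skip
# the weighted-average computation.
print(is_uniform(0.25, {1: 0.25, 2: 0.25, 3: 0.25}))  # True
print(is_uniform(0.4, {1: 0.2, 2: 0.2, 3: 0.2}))      # False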