Skip to content

Commit

Permalink
Simpler logic for dst_weighting_enabled and weighted_average_computation
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanbin Hu committed Mar 21, 2021
1 parent e5f8722 commit 9f2f55d
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions bluefog/torch/mpi_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from contextlib import contextmanager
from typing import List, Dict, Union, Optional

import numpy as np
import torch

from bluefog.torch import mpi_lib # C library
Expand Down Expand Up @@ -488,10 +489,9 @@ def _neighbor_allreduce_nonblocking(tensor, output, self_weight, src_weights,
"enabling dynamic topology.")
else:
dynamic_neighbors_enabled = True
dst_weighting_enabled = True
if isinstance(dst_weights, list):
dst_weights = {dst:1.0 for dst in dst_weights}
dst_weighting_enabled = False
dst_weighting_enabled = not np.allclose(list(dst_weights.values()), 1.0)
if self_weight is None and src_weights is None:
# Implying this is static graph.
if is_topo_weighted():
Expand All @@ -515,13 +515,8 @@ def _neighbor_allreduce_nonblocking(tensor, output, self_weight, src_weights,
raise ValueError("The key of weights should only contain the ranks that belong to "
" in-neighbors and self rank.")
uniform_weights = 1.0/(len(src_weights)+1)
weighted_average_computation = False
if abs(self_weight - uniform_weights) > 1e-6:
weighted_average_computation = True
for n_weights in src_weights.values():
if abs(n_weights - uniform_weights) > 1e-6:
weighted_average_computation = True
break
weighted_average_computation = not(np.isclose(self_weight, uniform_weights) and
np.allclose(list(src_weights.values()), uniform_weights))
else:
raise ValueError("Arguments self_weight and neighbor_weights have to be presented at "
"the same time")
Expand Down

0 comments on commit 9f2f55d

Please sign in to comment.