Make src_weights as std::map, and simplify logic for PerformNeighborAllreduceCallback
Hanbin Hu committed Mar 26, 2021
1 parent 2fe9621 commit 562b231
Showing 1 changed file with 17 additions and 46 deletions.
bluefog/torch/mpi_ops.cc (17 additions, 46 deletions)
@@ -98,13 +98,12 @@ std::function<std::function<void(const Status&)>(std::function<void()>)>

 void PerformNeighborAllreduceCallback(::torch::Tensor tensor, ::torch::Tensor output,
                                       double self_weight,
-                                      const std::unordered_map<int, double>& src_weights,
-                                      const std::vector<int>& src_neighbors,
+                                      const std::map<int, double>& src_weights,
                                       bool avg_computation,
                                       bool dynamic_neighbors_enabled,
                                       bool is_hierarchical) {
   int src_size = bluefog_neighbor_size();
-  if (dynamic_neighbors_enabled) src_size = src_neighbors.size();
+  if (dynamic_neighbors_enabled) src_size = src_weights.size();
   if (src_size > 0) {
     ::torch::Tensor output_buffer = MaybeCopyToTensorBuffer(output);
     ::torch::Tensor tensor_buffer = MaybeCopyToTensorBuffer(tensor);
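
The move from std::unordered_map to std::map is what makes the simplification below safe: std::map iterates its keys in ascending order, so walking src_weights visits source neighbors in the same sorted-by-rank order in which their segments are laid out in the receive buffer, while std::unordered_map guarantees no order at all. A minimal sketch of that guarantee (illustration only, not part of the commit):

// Sketch: demonstrates the ordering guarantee the new callback relies on.
#include <iostream>
#include <map>

int main() {
  std::map<int, double> src_weights = {{7, 0.2}, {1, 0.5}, {3, 0.3}};
  // std::map always iterates in ascending key order: ranks 1, 3, 7.
  // This matches the sorted src_neighbors vector built in DoNeighborAllreduce.
  for (const auto& kv : src_weights) {
    std::cout << "rank " << kv.first << " -> weight " << kv.second << "\n";
  }
  return 0;
}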
@@ -118,40 +117,17 @@ void PerformNeighborAllreduceCallback(::torch::Tensor tensor, ::torch::Tensor output,

     // If avg_computation is set to False, a sum computation takes place instead.
     if (avg_computation) {
-      // 1) For a distributed graph topology, created with
-      //    MPI_Dist_graph_create, the sequence of neighbors in the send and
-      //    receive buffers at each process is defined as the sequence returned
-      //    by MPI_Dist_graph_neighbors for destinations and sources,
-      //    respectively. 2) MPI_Dist_graph_neighbors: If the communicator was
-      //    created with MPI_Dist_graph_create_adjacent then the order of the
-      //    values in sources and destinations is identical to the input that
-      //    was used by the process with the same rank in comm_old in the
-      //    creation call.
-      int* sources_ptr = nullptr;
-      if (!dynamic_neighbors_enabled) {
-        int indgree = 0;
-        int outdegree = 0;
-        int* destinations_ptr = nullptr;
-        bluefog_load_topology(&indgree, sources_ptr, &outdegree,
-                              destinations_ptr);
-      }
       auto output_reduced = output_buffer.slice(0, 0, first_dim);
-      for (int i = 0; i < src_size; i++) {
-        double weight = 0.0;
-        int src_rank;
-        if (!dynamic_neighbors_enabled) src_rank = *(sources_ptr+i);
-        else src_rank = src_neighbors[i];
-        auto it = src_weights.find(src_rank);
-        if (it != src_weights.end()) {
-          weight = it->second;
-        }
-
+      int i = 0;
+      for (auto kv : src_weights) {
+        double weight = kv.second;
         if (i == 0) {
           output_reduced.mul_(weight);
         } else {
           output_reduced.add_(
               output_buffer.slice(0, i * first_dim, (i + 1) * first_dim), weight);
         }
+        ++i;
       }
       output_buffer.resize_(shape_vector);
       output_buffer.add_(tensor_buffer, self_weight);
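
For context, the buffer this loop reduces is assumed to hold one segment of first_dim rows per source neighbor, stacked along dimension 0 in rank order. The loop scales the first segment in place, accumulates every other segment into it with its weight, and finally adds the self contribution. A self-contained libtorch sketch of the same reduction, with made-up sizes and weights (the real shape_vector and buffers come from the surrounding function):

// Sketch: segment-wise weighted average over a stacked receive buffer.
#include <torch/torch.h>
#include <map>

int main() {
  const int64_t first_dim = 2, cols = 3;
  const std::map<int, double> src_weights = {{1, 0.25}, {4, 0.25}};
  const double self_weight = 0.5;
  // Fake receive buffer: one all-ones segment per source neighbor, in rank order.
  torch::Tensor output_buffer =
      torch::ones({static_cast<int64_t>(src_weights.size()) * first_dim, cols});
  torch::Tensor tensor_buffer = torch::ones({first_dim, cols});

  auto output_reduced = output_buffer.slice(0, 0, first_dim);
  int i = 0;
  for (const auto& kv : src_weights) {
    if (i == 0) {
      output_reduced.mul_(kv.second);  // scale the first segment in place
    } else {
      output_reduced.add_(
          output_buffer.slice(0, i * first_dim, (i + 1) * first_dim), kv.second);
    }
    ++i;
  }
  output_buffer.resize_({first_dim, cols});        // keep only the reduced segment
  output_buffer.add_(tensor_buffer, self_weight);  // add the self contribution
  // Every entry is now 0.25 + 0.25 + 0.5 = 1.0.
  return 0;
}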
@@ -431,10 +407,12 @@ int DoNeighborAllreduce(::torch::Tensor tensor, ::torch::Tensor output,

   auto callback_wrapper = GetCallbackWrapper(handle, timeline_ptr, op_name, tid);

-  std::vector<int> src_neighbors;
+  std::map<int, double> src_weights_map;
   for (auto kv : src_weights)
+    src_weights_map.insert(kv);
+  std::vector<int> src_neighbors;
+  for (auto kv : src_weights_map)
     src_neighbors.push_back(kv.first);
-  std::sort(src_neighbors.begin(), src_neighbors.end());

   std::vector<int> dst_neighbors;
   for (auto kv : dst_weights)
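
Copying the caller-supplied std::unordered_map into a std::map sorts the neighbor ranks as a side effect of insertion, which is why the explicit std::sort on src_neighbors can be dropped. A hypothetical standalone version of this construction:

// Sketch: std::map keeps keys sorted, replacing the removed std::sort.
#include <map>
#include <unordered_map>
#include <vector>

int main() {
  const std::unordered_map<int, double> src_weights = {{5, 0.1}, {2, 0.9}};
  const std::map<int, double> src_weights_map(src_weights.begin(), src_weights.end());
  std::vector<int> src_neighbors;
  for (const auto& kv : src_weights_map)
    src_neighbors.push_back(kv.first);  // already ascending: 2, 5
  return 0;
}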
@@ -462,16 +440,12 @@ int DoNeighborAllreduce(::torch::Tensor tensor, ::torch::Tensor output,
         bf_src_neighbors, bf_dst_neighbors, bf_dst_weights_vec,
         dynamic_neighbors_enabled, dst_weighting_enabled, is_hierarchical,
         enable_topo_check, op_name, CPU_DEVICE_ID,
-        callback_wrapper([self_weight, src_weights, avg_computation,
-                          cpu_output, tensor, src_neighbors,
-                          dynamic_neighbors_enabled, is_hierarchical, output,
-                          device]() mutable {
+        callback_wrapper([self_weight, src_weights_map, avg_computation, cpu_output, tensor,
+                          dynamic_neighbors_enabled, is_hierarchical, output, device]() mutable {
           with_device device_guard(device);
           output.copy_(cpu_output);
-          PerformNeighborAllreduceCallback(tensor, output, self_weight,
-                                           src_weights, src_neighbors,
-                                           avg_computation,
-                                           dynamic_neighbors_enabled,
+          PerformNeighborAllreduceCallback(tensor, output, self_weight, src_weights_map,
+                                           avg_computation, dynamic_neighbors_enabled,
                                            is_hierarchical);
         }));
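
Note that the lambda now captures src_weights_map by value, so the weights remain valid when the callback fires asynchronously after the enqueued operation completes, even though DoNeighborAllreduce has long since returned. A small sketch of that by-value capture pattern (hypothetical names, not the Bluefog API):

// Sketch: by-value capture keeps the map alive for a deferred callback.
#include <functional>
#include <iostream>
#include <map>

std::function<void()> MakeCallback() {
  std::map<int, double> src_weights_map = {{1, 0.5}, {3, 0.5}};
  // Captured by value: the callback owns its own copy of the map,
  // so it stays valid after this function returns.
  return [src_weights_map]() mutable {
    for (const auto& kv : src_weights_map)
      std::cout << "rank " << kv.first << " -> " << kv.second << "\n";
  };
}

int main() {
  auto cb = MakeCallback();
  cb();  // safe: operates on the captured copy
  return 0;
}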

@@ -486,13 +460,10 @@ int DoNeighborAllreduce(::torch::Tensor tensor, ::torch::Tensor output,
         bf_src_neighbors, bf_dst_neighbors, bf_dst_weights_vec,
         dynamic_neighbors_enabled, dst_weighting_enabled,
         is_hierarchical, enable_topo_check, op_name, device,
-        callback_wrapper([self_weight, src_weights, avg_computation,
-                          src_neighbors, dynamic_neighbors_enabled,
+        callback_wrapper([self_weight, src_weights_map, avg_computation, dynamic_neighbors_enabled,
                           is_hierarchical, tensor, output]() mutable {
-          PerformNeighborAllreduceCallback(tensor, output, self_weight,
-                                           src_weights, src_neighbors,
-                                           avg_computation,
-                                           dynamic_neighbors_enabled,
+          PerformNeighborAllreduceCallback(tensor, output, self_weight, src_weights_map,
+                                           avg_computation, dynamic_neighbors_enabled,
                                            is_hierarchical);
         }));
   ThrowIfError(enqueue_result);
