
Commit

Add TODO #80 and #81, and simplify the logic for dst_weight
Hanbin Hu committed Mar 26, 2021
1 parent 562b231 commit 6764efa
Showing 6 changed files with 72 additions and 121 deletions.
1 change: 1 addition & 0 deletions bluefog/common/global_state.h
@@ -80,6 +80,7 @@ struct BluefogGlobalState {
// Threshold for Tensor Fusion. All tensors that occupy memory beyond this
// threshold will be fused.
int64_t tensor_fusion_threshold = 8 * 1024 * 1024;
int64_t tensor_fusion_threshold_for_dst_weight = 8 * 1024 * 1024;
FusionBufferManager fusion_buffer;

// Because setting topology happens in the main thread instead of communication
1 change: 1 addition & 0 deletions bluefog/common/mpi_context.h
@@ -237,6 +237,7 @@ class MPIContext {
MPI_Datatype mpi_float16_t;
MPI_Op mpi_float16_sum;

// TODO(hhb): #80 We should use a common context for MPI and NCCL controller for CUDA usage.
#if HAVE_CUDA
// CUDA Stream
cudaStream_t stream;
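Note: TODO #80 above tracks unifying the CUDA state that is currently duplicated between the MPI and NCCL controllers. The snippet below is only a hypothetical sketch of what such a shared context could look like; the struct name and methods are not from the Bluefog sources.

// Hypothetical sketch for TODO #80: a single CUDA holder that both the MPI
// and NCCL controllers could share instead of each owning its own stream.
#if HAVE_CUDA
#include <cuda_runtime.h>

struct SharedGpuContext {
  cudaStream_t stream = nullptr;

  void Initialize() {
    if (stream == nullptr) {
      cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
    }
  }

  void Finalize() {
    if (stream != nullptr) {
      cudaStreamDestroy(stream);
      stream = nullptr;
    }
  }
};
#endif  // HAVE_CUDA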
134 changes: 43 additions & 91 deletions bluefog/common/mpi_controller.cc
@@ -459,16 +459,10 @@ void MPIController::NeighborAllreduce(TensorTableEntry& entry) {

if (!entry.is_hierarchical) {
if (!entry.dynamic_neighbors_enabled) {
int ret_code = MPI_Neighbor_allgather(
MPICHECK(MPI_Neighbor_allgather(
sendbuf, num_elements, mpi_ctx_.GetMPIDataType(entry.tensor),
buffer_data, num_elements, mpi_ctx_.GetMPIDataType(entry.output),
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH));
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Neighbor_allreduce (through neighbor_allgather) failed, see "
"MPI "
"output for details.");
}
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH)));
} else {
int nsend = entry.send_neighbors->size();
int nrecv = entry.recv_neighbors->size();
@@ -490,17 +484,10 @@ void MPIController::NeighborAllreduce(TensorTableEntry& entry) {
for (int i = 0; i < nrecv; ++i) {
void* recvbuf = (void*)(static_cast<const char*>(entry.output->data()) +
num_elements * i * element_size);
int ret_code = MPI_Irecv(
recvbuf, num_elements, mpi_ctx_.GetMPIDataType(entry.output),
entry.recv_neighbors->at(i),
mpi_ctx_.rank_ + entry.recv_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH),
&requests[i + nsend]);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Irecv (for dynamic neighbor_allreduce) failed, see MPI "
"output for details.");
}
MPICHECK(MPI_Irecv(recvbuf, num_elements,
mpi_ctx_.GetMPIDataType(entry.output), entry.recv_neighbors->at(i),
/*tag=*/mpi_ctx_.rank_ + entry.recv_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL), &requests[i + nsend]));
}

if (entry.dst_weighting_enabled) {
@@ -509,33 +496,16 @@ void MPIController::NeighborAllreduce(TensorTableEntry& entry) {
std::this_thread::sleep_for(std::chrono::nanoseconds(100));
}
}
for (int i = 0; i < nsend; ++i) {
int ret_code = MPI_Isend(
weighted_tensors[i].get()->data(), num_elements,
mpi_ctx_.GetMPIDataType(entry.tensor), entry.send_neighbors->at(i),
mpi_ctx_.rank_ + entry.send_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH), &requests[i]);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Isend (for dynamic neighbor_allreduce) failed, see MPI "
"output for details.");
}
}
} else {
for (int i = 0; i < nsend; ++i) {
int ret_code = MPI_Isend(
sendbuf, num_elements, mpi_ctx_.GetMPIDataType(entry.tensor),
entry.send_neighbors->at(i),
mpi_ctx_.rank_ + entry.send_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH), &requests[i]);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Isend (for dynamic neighbor_allreduce) failed, see MPI "
"output for details.");
}
}
}

for (int i = 0; i < nsend; ++i) {
const void* buffer_send = sendbuf;
if (entry.dst_weighting_enabled)
buffer_send = weighted_tensors[i].get()->data();
MPICHECK(MPI_Isend(buffer_send, num_elements,
mpi_ctx_.GetMPIDataType(entry.tensor), entry.send_neighbors->at(i),
/*tag=*/mpi_ctx_.rank_ + entry.send_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL), &requests[i]));
}
MPI_Waitall(nsend + nrecv, requests.data(), statuses.data());
error_message =
GenerateNeighborExchangeErrorMessage(statuses, nsend, nrecv);
@@ -551,6 +521,11 @@ void MPIController::NeighborAllreduce(TensorTableEntry& entry) {
"Local size is smaller than 2, in this case, you should use "
"neighbor_allreduce instead of hierarchical_neighbor_allreduce.");
}
if (entry.dst_weighting_enabled) {
throw std::runtime_error(
"Under hierarchical neighbor_allreduce, argument "
"dst_weight should not be enabled for now.");
}
// 1. In-place allreduce
MPI_Allreduce(MPI_IN_PLACE, (void*)sendbuf, num_elements,
mpi_ctx_.GetMPIDataType(entry.tensor), MPI_SUM,
@@ -644,8 +619,8 @@ void MPIController::NeighborAllreduce(std::vector<TensorTableEntry>& entries) {
if (first_entry.enable_topo_check && first_entry.dynamic_neighbors_enabled) {
if (first_entry.is_hierarchical) {
// TODO: support check.
BFLOG(INFO) << "Request to check topology for hierarchical neighbor "
<< "allreduce ops but it is not supported yet.";
BFLOG(WARNING) << "Request to check topology for hierarchical neighbor "
<< "allreduce ops but it is not supported yet.";
}
is_topo_check_fail = CheckNeighborSendRecvPattern(
first_entry.send_neighbors.get(), first_entry.recv_neighbors.get(), first_entry.tensor_name,
@@ -709,15 +684,10 @@ void MPIController::NeighborAllreduce(std::vector<TensorTableEntry>& entries) {

if (!first_entry.is_hierarchical) {
if (!first_entry.dynamic_neighbors_enabled) {
int ret_code = MPI_Neighbor_allgather(
MPICHECK(MPI_Neighbor_allgather(
fused_input_data, num_elements, mpi_ctx_.GetMPIDataType(first_entry.tensor),
buffer_data, num_elements, mpi_ctx_.GetMPIDataType(first_entry.output),
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH));
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Neighbor_allreduce (through neighbor_allgather) failed, see MPI "
"output for details.");
}
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH)));
} else {
int nsend = first_entry.send_neighbors->size();
int nrecv = first_entry.recv_neighbors->size();
@@ -726,48 +696,25 @@ void MPIController::NeighborAllreduce(std::vector<TensorTableEntry>& entries) {
for (int i = 0; i < nrecv; ++i) {
void* recvbuf =
(void*)((uint8_t*)buffer_data + num_elements * i * element_size);
int ret_code = MPI_Irecv(recvbuf, num_elements,
mpi_ctx_.GetMPIDataType(first_entry.output),
first_entry.recv_neighbors->at(i),
/*tag=*/mpi_ctx_.rank_ + first_entry.recv_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH),
&requests[i + nsend]);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Irecv (for dynamic neighbor_allreduce) failed, see MPI output "
"for details.");
}
MPICHECK(MPI_Irecv(recvbuf, num_elements,
mpi_ctx_.GetMPIDataType(first_entry.output), first_entry.recv_neighbors->at(i),
/*tag=*/mpi_ctx_.rank_ + first_entry.recv_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL), &requests[i + nsend]));
}
if (!first_entry.dst_weighting_enabled) {
for (int i = 0; i < nsend; ++i) {
int ret_code = MPI_Isend(
fused_input_data, num_elements, mpi_ctx_.GetMPIDataType(first_entry.tensor),
first_entry.send_neighbors->at(i),
/*tag=*/mpi_ctx_.rank_ + first_entry.send_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH), &requests[i]);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Isend (for dynamic neighbor_allreduce) failed, see MPI output "
"for details.");
}
}
} else {
#if HAVE_CUDA
if (first_entry.dst_weighting_enabled && first_entry.device != CPU_DEVICE_ID) {
cudaStreamSynchronize(mpi_ctx_.stream);
}
#endif
for (int i = 0; i < nsend; ++i) {
void* sendbuf =
for (int i = 0; i < nsend; ++i) {
const void* sendbuf = fused_input_data;
if (first_entry.dst_weighting_enabled)
sendbuf =
(void*)((uint8_t*)weighted_fused_input_data + num_elements * i * element_size);
int ret_code = MPI_Isend(sendbuf, num_elements,
mpi_ctx_.GetMPIDataType(first_entry.tensor), first_entry.send_neighbors->at(i),
/*tag=*/mpi_ctx_.rank_ + first_entry.send_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GRAPH), &requests[i]);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Isend (for dynamic neighbor_allreduce) failed, see MPI output "
"for details.");
}
}
MPICHECK(MPI_Isend(sendbuf, num_elements,
mpi_ctx_.GetMPIDataType(first_entry.tensor), first_entry.send_neighbors->at(i),
/*tag=*/mpi_ctx_.rank_ + first_entry.send_neighbors->at(i),
mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL), &requests[i]));
}
MPI_Waitall(nsend + nrecv, requests.data(), statuses.data());
error_message =
@@ -784,6 +731,11 @@ void MPIController::NeighborAllreduce(std::vector<TensorTableEntry>& entries) {
"Local size is smaller than 2, in this case, you should use "
"neighbor_allreduce instead of hierarchical_neighbor_allreduce.");
}
if (first_entry.dst_weighting_enabled) {
throw std::runtime_error(
"Under hierarchical neighbor_allreduce, argument "
"dst_weight should not be enabled for now.");
}
// 1. In-place allreduce
MPI_Allreduce(MPI_IN_PLACE, (void*)fused_input_data, num_elements,
mpi_ctx_.GetMPIDataType(first_entry.tensor), MPI_SUM,
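Note: the per-call ret_code checks removed above are folded into the MPICHECK wrapper. Its definition is not shown in this diff; the following is only a sketch of the usual shape of such a macro, assuming it turns a non-MPI_SUCCESS return code into a std::runtime_error much like the removed code did.

// Sketch only: the real MPICHECK is defined elsewhere in the Bluefog sources.
#include <mpi.h>
#include <stdexcept>
#include <string>

#define MPICHECK(call)                                                   \
  do {                                                                   \
    int mpi_ret_code = (call);                                           \
    if (mpi_ret_code != MPI_SUCCESS) {                                   \
      char mpi_err_str[MPI_MAX_ERROR_STRING];                            \
      int mpi_err_len = 0;                                               \
      MPI_Error_string(mpi_ret_code, mpi_err_str, &mpi_err_len);         \
      throw std::runtime_error(std::string(#call) + " failed: " +        \
                               std::string(mpi_err_str, mpi_err_len));   \
    }                                                                    \
  } while (0)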
53 changes: 25 additions & 28 deletions bluefog/common/nccl_controller.cc
@@ -789,23 +789,19 @@ void NCCLController::NeighborAllreduce(TensorTableEntry& entry) {
NCCLCHECK(ncclRecv(recvbuf, num_elements, GetNCCLDataType(entry.tensor),
recv_rank, nccl_ctx_.nccl_comm, nccl_ctx_.stream));
}
if(entry.dst_weighting_enabled)
{
if(entry.dst_weighting_enabled) {
if (ready_event != nullptr) {
while (!ready_event->Ready()) {
std::this_thread::sleep_for(std::chrono::nanoseconds(100));
}
}
for (size_t i = 0; i < entry.send_neighbors->size(); ++i) {
NCCLCHECK(ncclSend(weighted_tensors[i].get()->data(), num_elements,
GetNCCLDataType(entry.tensor), entry.send_neighbors->at(i),
nccl_ctx_.nccl_comm, nccl_ctx_.stream));
}
} else {
for (int send_rank : *entry.send_neighbors) {
NCCLCHECK(ncclSend(sendbuf, num_elements, GetNCCLDataType(entry.tensor),
send_rank, nccl_ctx_.nccl_comm, nccl_ctx_.stream));
}
}
for (size_t i = 0; i < entry.send_neighbors->size(); ++i) {
const void* buffer_send = sendbuf;
if (entry.dst_weighting_enabled)
buffer_send = weighted_tensors[i].get()->data();
NCCLCHECK(ncclSend(buffer_send, num_elements, GetNCCLDataType(entry.tensor),
entry.send_neighbors->at(i), nccl_ctx_.nccl_comm, nccl_ctx_.stream));
}
}
ncclGroupEnd();
@@ -816,6 +812,11 @@ void NCCLController::NeighborAllreduce(TensorTableEntry& entry) {
"send_machine_neighbors should "
"not be empty.");
}
if (entry.dst_weighting_enabled) {
throw std::runtime_error(
"Under hierarchical neighbor_allreduce, argument "
"dst_weight should not be enabled for now.");
}
// 1. In-place allreduce for all local ranks. Note it is sum, so we need to
// divided by local size at call back stage.
NCCLCHECK(ncclAllReduce(sendbuf, (void*)sendbuf, num_elements,
@@ -1093,22 +1094,13 @@ void NCCLController::NeighborAllreduce(std::vector<TensorTableEntry>& entries) {
GetNCCLDataType(first_entry.tensor), recv_rank,
nccl_ctx_.nccl_comm, nccl_ctx_.stream));
}
if (!first_entry.dst_weighting_enabled)
{
for (int send_rank : *first_entry.send_neighbors) {
NCCLCHECK(ncclSend(fused_input_data, num_elements,
GetNCCLDataType(first_entry.tensor), send_rank,
nccl_ctx_.nccl_comm, nccl_ctx_.stream));
}
} else {
for (size_t i = 0; i < first_entry.send_neighbors->size(); ++i) {
void* sendbuf =
(void*)((uint8_t*)weighted_fused_input_data + num_elements * i * element_size);
NCCLCHECK(ncclSend(sendbuf, num_elements,
GetNCCLDataType(first_entry.tensor),
first_entry.send_neighbors->at(i),
nccl_ctx_.nccl_comm, nccl_ctx_.stream));
}
for (size_t i = 0; i < first_entry.send_neighbors->size(); ++i) {
const void* sendbuf = fused_input_data;
if (first_entry.dst_weighting_enabled)
sendbuf = (void*)((uint8_t*)weighted_fused_input_data + num_elements * i * element_size);
NCCLCHECK(ncclSend(sendbuf, num_elements, GetNCCLDataType(first_entry.tensor),
first_entry.send_neighbors->at(i),
nccl_ctx_.nccl_comm, nccl_ctx_.stream));
}
}
ncclGroupEnd();
@@ -1124,6 +1116,11 @@ void NCCLController::NeighborAllreduce(std::vector<TensorTableEntry>& entries) {
"neighbor_allreduce instead of hierarchical_neighbor_allreduce."
);
}
if (first_entry.dst_weighting_enabled) {
throw std::runtime_error(
"Under hierarchical neighbor_allreduce, argument "
"dst_weight should not be enabled for now.");
}

// 1. In-place allreduce for all local ranks. Note it is sum, so we need to
// divided by local size at call back stage.
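Note: both the MPI and NCCL fused paths rely on the same layout for the dst_weight data: one scaled copy of the fused input per send neighbor, stored back to back in the weight buffer. A minimal sketch of that offset computation follows; the helper name is illustrative and not from the source.

// Mirrors the pointer arithmetic used in the diff:
// weighted_fused_input_data + num_elements * i * element_size.
#include <cstdint>

inline const void* WeightedSliceForNeighbor(const void* weighted_fused_input_data,
                                            int64_t num_elements,
                                            int64_t element_size,
                                            int64_t neighbor_index) {
  const uint8_t* base = static_cast<const uint8_t*>(weighted_fused_input_data);
  return base + num_elements * neighbor_index * element_size;
}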
2 changes: 1 addition & 1 deletion bluefog/common/operations.cc
@@ -785,7 +785,7 @@ void PerformOperationWithFusion(std::vector<TensorTableEntry>& entries) {
Status status_dst_weight = Status::OK();
if (first_entry.dst_weighting_enabled) {
status_dst_weight = bluefog_global.fusion_buffer.InitializeWeightBuffer(
bluefog_global.tensor_fusion_threshold, mpi_context.size_,
bluefog_global.tensor_fusion_threshold_for_dst_weight, mpi_context.size_,
first_entry.device, first_entry.context,
[&]() { timeline.ActivityStartAll(entries, "INIT_WEIGHT_FUSION_BUFFER"); },
[&]() { timeline.ActivityEndAll(entries); });
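Note: PerformOperationWithFusion now sizes the dst_weight buffer from the dedicated tensor_fusion_threshold_for_dst_weight rather than the general fusion threshold. Presumably the weight buffer holds one scaled copy of a fused payload per possible destination, which would be why mpi_context.size_ is passed alongside the threshold; the capacity sketch below rests on that assumption and is not taken from InitializeWeightBuffer itself.

// Illustrative only: rough capacity under a one-copy-per-destination layout.
#include <cstdint>

int64_t EstimatedWeightBufferBytes(int64_t threshold_bytes, int world_size) {
  return threshold_bytes * static_cast<int64_t>(world_size);
}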
2 changes: 1 addition & 1 deletion bluefog/torch/mpi_ops.py
@@ -638,7 +638,7 @@ def neighbor_allreduce_nonblocking(tensor: torch.Tensor, *,
return _neighbor_allreduce_nonblocking(tensor, output, self_weight, src_weights,
dst_weights, enable_topo_check, name=name)


# TODO(hanbinhu) #81 Add dst_weight for hierarchical neighbor allreduce.
def hierarchical_neighbor_allreduce(tensor: torch.Tensor,
self_weight: float = None,
neighbor_machine_weights: Dict[int, float] = None,