Skip to content

Commit

Permalink
Better consideration for weight buffer size
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanbin Hu committed Mar 21, 2021
1 parent 9f2f55d commit d7a9310
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
2 changes: 2 additions & 0 deletions bluefog/common/operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,8 @@ void PerformOperationWithFusion(std::vector<TensorTableEntry>& entries) {
[&]() { timeline.ActivityStartAll(entries, "INIT_FUSION_BUFFER"); },
[&]() { timeline.ActivityEndAll(entries); });

// Applying dst_weight scales the tensor separately for each destination, which
// requires an extra weight buffer.
Status status_dst_weight = Status::OK();
if (first_entry.dst_weighting_enabled) {
status_dst_weight = bluefog_global.fusion_buffer.InitializeWeightBuffer(
Expand Down
7 changes: 4 additions & 3 deletions bluefog/common/tensor_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,19 +155,20 @@ std::shared_ptr<PersistentBuffer> FusionBufferManager::GetBuffer(int device) {
}

Status FusionBufferManager::InitializeWeightBuffer(
int64_t threshold, int mpi_size, int device, std::shared_ptr<OpContext> context,
int64_t threshold, int world_size, int device, std::shared_ptr<OpContext> context,
std::function<void()> on_start_init, std::function<void()> on_end_init) {
auto& elem = weight_tensor_fusion_buffers_[device];
auto& buffer = elem.first;
int64_t& size = elem.second;
if (size != threshold*mpi_size) {
// threshold * (world_size - 1) is the upper bound on the weight buffer size.
if (size != threshold*(world_size-1)) {
buffer.reset();
size = 0;
}

if (buffer == nullptr) {
on_start_init();
size = threshold*mpi_size;
size = threshold*(world_size-1);

// Lazily allocate persistent buffer for Tensor Fusion and keep it
// forever per device.
Expand Down
4 changes: 2 additions & 2 deletions bluefog/common/tensor_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ class FusionBufferManager {
//
// Args:
// threshold: Size of the buffer in bytes.
// mpi_size: Size of MPI nodes.
// world_size: Number of MPI nodes.
// device: Device ID to associate the buffer.
// context: Framework used to create the buffer and associate it.
// on_start_init: Callback on starting buffer initialization.
// on_end_init: Callback on completing buffer initialization.
Status InitializeWeightBuffer(int64_t threshold,
int mpi_size,
int world_size,
int device,
std::shared_ptr<OpContext> context,
std::function<void()> on_start_init,
Expand Down

0 comments on commit d7a9310

Please sign in to comment.