Skip to content

Commit

Permalink
Merge pull request #520 from m-a-d-n-e-s-s/evaleev/feature/can-limit-…
Browse files Browse the repository at this point in the history
…max-msg-size-for-broadcast-and-reduce

introduced `WorldGopInterface::{set_,}max_reducebcast_msg_size()`
  • Loading branch information
evaleev authored Jan 25, 2024
2 parents de03265 + 87715d9 commit 8788aea
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 71 deletions.
14 changes: 13 additions & 1 deletion src/madness/world/safempi.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,13 @@ namespace SafeMPI {
return result;
}

/// \return the period of repeat of unique tags produced by unique_tag()
static int unique_tag_period() {
  // unique_tag() hands out values in the closed range [1024, 4094]
  // (cf. the documented temporary-tag range 1023<tag<4095), so the
  // tag sequence repeats with period max-min+1
  constexpr int min_tag_value = 1024;
  constexpr int max_tag_value = 4094;
  return max_tag_value - min_tag_value + 1;
}

/// Returns a unique tag reserved for long-term use (0<tag<1000)

/// Get a tag from this routine for long-term/repeated use.
Expand Down Expand Up @@ -810,7 +817,7 @@ namespace SafeMPI {
MADNESS_MPI_TEST(MPI_Barrier(pimpl->comm));
}

/// Returns a unique tag for temporary use (1023<tag<=4095)
/// Returns a unique tag for temporary use (1023<tag<4095)

/// These tags are intended for one time use to avoid tag
/// collisions with other messages around the same time period.
Expand All @@ -824,6 +831,11 @@ namespace SafeMPI {
return pimpl->unique_tag();
}

/// \return the period of repeat of unique tags produced by unique_tag()

/// Tags returned by unique_tag() are recycled with this period, i.e. two
/// calls separated by at least this many unique_tag() invocations may
/// receive the same tag value.
static int unique_tag_period() {
  // delegates to the implementation class
  return Impl::unique_tag_period();
}

/// Returns a unique tag reserved for long-term use (0<tag<1000)

/// Get a tag from this routine for long-term/repeated use.
Expand Down
5 changes: 5 additions & 0 deletions src/madness/world/test_world.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1178,10 +1178,15 @@ void test13(World& world) {
void test14(World& world) {

if (world.size() > 1) {
static size_t call_counter = 0;
++call_counter;

const auto n = 1 + std::numeric_limits<int>::max()/sizeof(int);
auto iarray = std::make_unique<int[]>(n);
iarray[0] = -1;
iarray[n-1] = -1;

world.gop.set_max_reducebcast_msg_size(std::numeric_limits<int>::max()/(std::min(10ul,call_counter)));
world.gop.broadcast(iarray.get(), n, 0);

if (world.rank() == 1) {
Expand Down
44 changes: 24 additions & 20 deletions src/madness/world/worldgop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,36 +170,40 @@ namespace madness {
fence();
}

/// Broadcasts bytes from process root while still processing AM & tasks
static void broadcast_impl(void* buf, int nbyte, ProcessID root, bool dowork, Tag bcast_tag, World &world) {
void WorldGopInterface::broadcast(void* buf, size_t nbyte, ProcessID root, bool dowork, Tag bcast_tag) {
if (bcast_tag < 0)
bcast_tag = world_.mpi.unique_tag();
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(root, parent, child0, child1);
const size_t max_msg_size =
static_cast<size_t>(max_reducebcast_msg_size());

auto broadcast_impl = [&, this](void *buf, int nbyte) {
SafeMPI::Request req0, req1;
ProcessID parent, child0, child1;
world.mpi.binary_tree_info(root, parent, child0, child1);

//print("BCAST TAG", bcast_tag);
// print("BCAST TAG", bcast_tag);

if (parent != -1) {
req0 = world.mpi.Irecv(buf, nbyte, MPI_BYTE, parent, bcast_tag);
World::await(req0, dowork);
req0 = world_.mpi.Irecv(buf, nbyte, MPI_BYTE, parent, bcast_tag);
World::await(req0, dowork);
}

if (child0 != -1) req0 = world.mpi.Isend(buf, nbyte, MPI_BYTE, child0, bcast_tag);
if (child1 != -1) req1 = world.mpi.Isend(buf, nbyte, MPI_BYTE, child1, bcast_tag);
if (child0 != -1)
req0 = world_.mpi.Isend(buf, nbyte, MPI_BYTE, child0, bcast_tag);
if (child1 != -1)
req1 = world_.mpi.Isend(buf, nbyte, MPI_BYTE, child1, bcast_tag);

if (child0 != -1) World::await(req0, dowork);
if (child1 != -1) World::await(req1, dowork);
}
if (child0 != -1)
World::await(req0, dowork);
if (child1 != -1)
World::await(req1, dowork);
};

/// Optimizations can be added for long messages
void WorldGopInterface::broadcast(void* buf, size_t nbyte, ProcessID root, bool dowork, Tag bcast_tag) {
if(bcast_tag < 0)
bcast_tag = world_.mpi.unique_tag();
const size_t int_max = static_cast<size_t>(std::numeric_limits<int>::max());
while (nbyte) {
const int n = static_cast<int>(std::min(int_max, nbyte));
broadcast_impl(buf, n, root, dowork, bcast_tag, world_);
const int n = static_cast<int>(std::min(max_msg_size, nbyte));
broadcast_impl(buf, n);
nbyte -= n;
buf = static_cast<char*>(buf) + n;
buf = static_cast<char *>(buf) + n;
}
}

Expand Down
190 changes: 140 additions & 50 deletions src/madness/world/worldgop.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,11 @@ namespace madness {
/// If native AM interoperates with MPI we probably should map these to MPI.
class WorldGopInterface {
private:
World& world_; ///< MPI interface
World& world_; ///< World object that this is a part of
std::shared_ptr<detail::DeferredCleanup> deferred_; ///< Deferred cleanup object.
bool debug_; ///< Debug mode
bool forbid_fence_=false; ///< forbid calling fence() in case of several active worlds
int max_reducebcast_msg_size_ = std::numeric_limits<int>::max(); ///< maximum size of messages (in bytes) sent by reduce and broadcast

friend class detail::DeferredCleanup;

Expand Down Expand Up @@ -624,11 +625,34 @@ namespace madness {
bool pause_during_epilogue = false,
bool debug = false);

/// Computes the initial value of max_reducebcast_msg_size_

/// Reads the MAD_MAX_REDUCEBCAST_MSG_SIZE environment variable; values
/// that are non-positive, unparseable, or out of int range fall back to
/// std::numeric_limits<int>::max() (with a warning on rank 0).
/// \return the initial maximum size (in bytes) of reduce/broadcast messages
int initial_max_reducebcast_msg_size() {
  int result = std::numeric_limits<int>::max();
  const auto* initial_max_reducebcast_msg_size_cstr = std::getenv("MAD_MAX_REDUCEBCAST_MSG_SIZE");
  if (initial_max_reducebcast_msg_size_cstr) {
    // strtol, unlike atoi, detects unparseable input (via endptr) and
    // saturates instead of invoking UB on overflow; reject anything that
    // is not a positive value representable as int
    char* endptr = nullptr;
    const long parsed = std::strtol(initial_max_reducebcast_msg_size_cstr, &endptr, 10);
    const bool parse_ok =
        endptr != initial_max_reducebcast_msg_size_cstr && parsed > 0 &&
        parsed <= static_cast<long>(std::numeric_limits<int>::max());
    const auto do_print = SafeMPI::COMM_WORLD.Get_rank() == 0 && !madness::quiet();
    if (!parse_ok) {
      if (do_print)
        std::cout
          << "!!MADNESS WARNING: Invalid value for environment variable MAD_MAX_REDUCEBCAST_MSG_SIZE.\n"
          << "!!MADNESS WARNING: MAD_MAX_REDUCEBCAST_MSG_SIZE = "
          << initial_max_reducebcast_msg_size_cstr << "\n";
      result = std::numeric_limits<int>::max();
    } else {
      result = static_cast<int>(parsed);
    }
    if(do_print) {
      std::cout
        << "MADNESS max msg size for GOP reduce/broadcast set to "
        << result << " bytes.\n";
    }
  }
  return result;
}

public:

// In the World constructor can ONLY rely on MPI and MPI being initialized
WorldGopInterface(World& world) :
    world_(world), deferred_(new detail::DeferredCleanup()), debug_(false),
    // seed the reduce/broadcast message-size cap from the environment
    max_reducebcast_msg_size_(initial_max_reducebcast_msg_size())
{ }

~WorldGopInterface() {
Expand All @@ -650,6 +674,26 @@ namespace madness {
forbid_fence_ = value;
return status;
}

/// Set the maximum size of messages (in bytes) sent by reduce and broadcast

/// \param sz the maximum size of messages (in bytes) sent by reduce and broadcast
/// \return the previous maximum size of messages (in bytes) sent by reduce and broadcast
/// \pre `sz>0`
int set_max_reducebcast_msg_size(int sz) {
  MADNESS_ASSERT(sz>0);
  std::swap(max_reducebcast_msg_size_,sz);
  // after the swap `sz` holds the OLD limit; returning the member here
  // would return the new value, violating the documented contract
  // (cf. set_forbid_fence, which also returns the previous value)
  return sz;
}


/// Returns the maximum size of messages (in bytes) sent by reduce and broadcast

/// The limit can be changed via set_max_reducebcast_msg_size() or set at
/// startup through the MAD_MAX_REDUCEBCAST_MSG_SIZE environment variable.
/// \return the maximum size of messages (in bytes) sent by reduce and broadcast
int max_reducebcast_msg_size() const {
  return max_reducebcast_msg_size_;
}

/// Synchronizes all processes in communicator ... does NOT fence pending AM or tasks
void barrier() {
long i = world_.rank();
Expand Down Expand Up @@ -727,50 +771,46 @@ namespace madness {
delete [] buf;
}

private:
/// Inplace global reduction (like MPI all_reduce) while still processing AM & tasks

/// Optimizations can be added for long messages and to reduce the memory footprint
template <typename T, class opT>
void reduce_impl(T* buf, int nelem, opT op) {
void reduce(T* buf, std::size_t nelem, opT op) {
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
const std::size_t nelem_per_maxmsg = max_reducebcast_msg_size() / sizeof(T);

auto buf0 = std::unique_ptr<T[]>(new T[nelem_per_maxmsg]);
auto buf1 = std::unique_ptr<T[]>(new T[nelem_per_maxmsg]);

auto reduce_impl = [&,this](T* buf, int nelem) {
MADNESS_ASSERT(nelem <= nelem_per_maxmsg);
SafeMPI::Request req0, req1;
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
Tag gsum_tag = world_.mpi.unique_tag();

T* buf0 = new T[nelem];
T* buf1 = new T[nelem];

if (child0 != -1) req0 = world_.mpi.Irecv(buf0, nelem*sizeof(T), MPI_BYTE, child0, gsum_tag);
if (child1 != -1) req1 = world_.mpi.Irecv(buf1, nelem*sizeof(T), MPI_BYTE, child1, gsum_tag);
if (child0 != -1) req0 = world_.mpi.Irecv(buf0.get(), nelem*sizeof(T), MPI_BYTE, child0, gsum_tag);
if (child1 != -1) req1 = world_.mpi.Irecv(buf1.get(), nelem*sizeof(T), MPI_BYTE, child1, gsum_tag);

if (child0 != -1) {
World::await(req0);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf0[i]);
World::await(req0);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf0[i]);
}
if (child1 != -1) {
World::await(req1);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf1[i]);
World::await(req1);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf1[i]);
}

delete [] buf0;
delete [] buf1;

if (parent != -1) {
req0 = world_.mpi.Isend(buf, nelem*sizeof(T), MPI_BYTE, parent, gsum_tag);
World::await(req0);
req0 = world_.mpi.Isend(buf, nelem*sizeof(T), MPI_BYTE, parent, gsum_tag);
World::await(req0);
}

broadcast(buf, nelem, 0);
}

public:
/// Inplace global reduction (like MPI all_reduce) while still processing AM & tasks
};

/// Optimizations can be added for long messages and to reduce the memory footprint
template <typename T, class opT>
void reduce(T* buf, std::size_t nelem, opT op) {
const std::size_t nelem_per_intmax_nbytes = std::numeric_limits<int>::max() / sizeof(T);
while (nelem) {
const int n = std::min(nelem_per_intmax_nbytes, nelem);
reduce_impl(buf, n, op);
const int n = std::min(nelem_per_maxmsg, nelem);
reduce_impl(buf, n);
nelem -= n;
buf += n;
}
Expand Down Expand Up @@ -862,44 +902,94 @@ namespace madness {
/// \return on rank 0 returns the concatenated vector, elsewhere returns an empty vector
template <typename T>
std::vector<T> concat0(const std::vector<T>& v, size_t bufsz=1024*1024) {
MADNESS_ASSERT(bufsz <= std::numeric_limits<int>::max());

SafeMPI::Request req0, req1;
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
Tag gsum_tag = world_.mpi.unique_tag();

MADNESS_ASSERT(bufsz <= std::numeric_limits<int>::max());
auto buf0 = std::unique_ptr<std::byte[]>(new std::byte[bufsz]);
auto buf1 = std::unique_ptr<std::byte[]>(new std::byte[bufsz]);

const int batch_size = static_cast<int>(std::min(static_cast<size_t>(max_reducebcast_msg_size()),bufsz));
std::deque<Tag> tags; // stores tags used to send each batch

unsigned char* buf0 = new unsigned char[bufsz];
unsigned char* buf1 = new unsigned char[bufsz];
auto batched_receives = [&,this](size_t buf_offset) {
MADNESS_ASSERT(batch_size <= bufsz);
Tag gsum_tag = world_.mpi.unique_tag();
tags.push_back(gsum_tag);

if (child0 != -1)
req0 = world_.mpi.Irecv(buf0.get() + buf_offset,
bufsz - batch_size, MPI_BYTE, child0,
gsum_tag);
if (child1 != -1)
req1 = world_.mpi.Irecv(buf1.get() + buf_offset,
bufsz - batch_size, MPI_BYTE, child1,
gsum_tag);

if (child0 != -1) {
World::await(req0);
}
if (child1 != -1) {
World::await(req1);
}
};

// receive data in batches
if (child0 != -1 || child1 != -1) {
size_t buf_offset = 0;
while (buf_offset < bufsz) {
batched_receives(buf_offset);
buf_offset += batch_size;
buf_offset = std::min(buf_offset, bufsz);
}
}

if (child0 != -1) req0 = world_.mpi.Irecv(buf0, bufsz, MPI_BYTE, child0, gsum_tag);
if (child1 != -1) req1 = world_.mpi.Irecv(buf1, bufsz, MPI_BYTE, child1, gsum_tag);

std::vector<T> left, right;
if (child0 != -1) {
World::await(req0);
archive::BufferInputArchive ar(buf0, bufsz);
ar & left;
archive::BufferInputArchive ar(buf0.get(), bufsz);
ar & left;
}
if (child1 != -1) {
World::await(req1);
archive::BufferInputArchive ar(buf1, bufsz);
ar & right;
for (unsigned int i=0; i<right.size(); ++i) left.push_back(right[i]);
archive::BufferInputArchive ar(buf1.get(), bufsz);
ar & right;
for (unsigned int i = 0; i < right.size(); ++i)
left.push_back(right[i]);
}

for (unsigned int i=0; i<v.size(); ++i) left.push_back(v[i]);

// send data in batches
if (parent != -1) {
archive::BufferOutputArchive ar(buf0, bufsz);
ar & left;
req0 = world_.mpi.Isend(buf0, ar.size(), MPI_BYTE, parent, gsum_tag);
archive::BufferOutputArchive ar(buf0.get(), bufsz);
ar & left;
const auto total_nbytes_to_send = ar.size();

auto batched_send = [&,this](size_t buf_offset) {
MADNESS_ASSERT(batch_size <= bufsz);
Tag gsum_tag;
if (tags.empty()) {
gsum_tag = world_.mpi.unique_tag();
} else {
gsum_tag = tags.front();
tags.pop_front();
}

const auto nbytes_to_send = static_cast<int>(std::min(static_cast<size_t>(batch_size), total_nbytes_to_send - buf_offset));
req0 = world_.mpi.Isend(buf0.get() + buf_offset, nbytes_to_send, MPI_BYTE, parent,
gsum_tag);
World::await(req0);
};

size_t buf_offset = 0;
while (buf_offset < bufsz) {
batched_send(buf_offset);
buf_offset += batch_size;
buf_offset = std::min(buf_offset, bufsz);
}
}

delete [] buf0;
delete [] buf1;

if (parent == -1) return left;
else return std::vector<T>();
}
Expand Down

0 comments on commit 8788aea

Please sign in to comment.