Skip to content

Commit

Permalink
WorldGOPInterface::concat0 also uses max_reducebcast_msg_size + clean…
Browse files Browse the repository at this point in the history
…up of reduce/broadcast
  • Loading branch information
evaleev committed Jan 23, 2024
1 parent 2628cb9 commit 87715d9
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 67 deletions.
14 changes: 13 additions & 1 deletion src/madness/world/safempi.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,13 @@ namespace SafeMPI {
return result;
}

/// \return the period with which the tags produced by unique_tag() repeat,
///         computed as the number of integers between the lower tag bound
///         (1024) and the upper tag bound (4094), both inclusive
static int unique_tag_period() {
    constexpr int first_tag = 1024;
    constexpr int last_tag = 4094;
    // inclusive range, hence the +1
    constexpr int num_distinct_tags = last_tag - first_tag + 1;
    return num_distinct_tags;
}

/// Returns a unique tag reserved for long-term use (0<tag<1000)

/// Get a tag from this routine for long-term/repeated use.
Expand Down Expand Up @@ -810,7 +817,7 @@ namespace SafeMPI {
MADNESS_MPI_TEST(MPI_Barrier(pimpl->comm));
}

/// Returns a unique tag for temporary use (1023<tag<=4095)
/// Returns a unique tag for temporary use (1023<tag<4095)

/// These tags are intended for one time use to avoid tag
/// collisions with other messages around the same time period.
Expand All @@ -824,6 +831,11 @@ namespace SafeMPI {
return pimpl->unique_tag();
}

/// \return the period of repeat of the unique tags produced by unique_tag()
/// (the value is obtained by forwarding to Impl::unique_tag_period())
static int unique_tag_period() {
return Impl::unique_tag_period();
}

/// Returns a unique tag reserved for long-term use (0<tag<1000)

/// Get a tag from this routine for long-term/repeated use.
Expand Down
42 changes: 23 additions & 19 deletions src/madness/world/worldgop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,36 +170,40 @@ namespace madness {
fence();
}

/// Broadcasts bytes from process root while still processing AM & tasks
static void broadcast_impl(void* buf, int nbyte, ProcessID root, bool dowork, Tag bcast_tag, World &world) {
void WorldGopInterface::broadcast(void* buf, size_t nbyte, ProcessID root, bool dowork, Tag bcast_tag) {
if (bcast_tag < 0)
bcast_tag = world_.mpi.unique_tag();
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(root, parent, child0, child1);
const size_t max_msg_size =
static_cast<size_t>(max_reducebcast_msg_size());

auto broadcast_impl = [&, this](void *buf, int nbyte) {
SafeMPI::Request req0, req1;
ProcessID parent, child0, child1;
world.mpi.binary_tree_info(root, parent, child0, child1);

//print("BCAST TAG", bcast_tag);
// print("BCAST TAG", bcast_tag);

if (parent != -1) {
req0 = world.mpi.Irecv(buf, nbyte, MPI_BYTE, parent, bcast_tag);
World::await(req0, dowork);
req0 = world_.mpi.Irecv(buf, nbyte, MPI_BYTE, parent, bcast_tag);
World::await(req0, dowork);
}

if (child0 != -1) req0 = world.mpi.Isend(buf, nbyte, MPI_BYTE, child0, bcast_tag);
if (child1 != -1) req1 = world.mpi.Isend(buf, nbyte, MPI_BYTE, child1, bcast_tag);
if (child0 != -1)
req0 = world_.mpi.Isend(buf, nbyte, MPI_BYTE, child0, bcast_tag);
if (child1 != -1)
req1 = world_.mpi.Isend(buf, nbyte, MPI_BYTE, child1, bcast_tag);

if (child0 != -1) World::await(req0, dowork);
if (child1 != -1) World::await(req1, dowork);
}
if (child0 != -1)
World::await(req0, dowork);
if (child1 != -1)
World::await(req1, dowork);
};

/// Optimizations can be added for long messages
void WorldGopInterface::broadcast(void* buf, size_t nbyte, ProcessID root, bool dowork, Tag bcast_tag) {
if(bcast_tag < 0)
bcast_tag = world_.mpi.unique_tag();
const size_t max_msg_size = static_cast<size_t>(max_reducebcast_msg_size());
while (nbyte) {
const int n = static_cast<int>(std::min(max_msg_size, nbyte));
broadcast_impl(buf, n, root, dowork, bcast_tag, world_);
broadcast_impl(buf, n);
nbyte -= n;
buf = static_cast<char*>(buf) + n;
buf = static_cast<char *>(buf) + n;
}
}

Expand Down
140 changes: 93 additions & 47 deletions src/madness/world/worldgop.h
Original file line number Diff line number Diff line change
Expand Up @@ -771,50 +771,46 @@ namespace madness {
delete [] buf;
}

private:
/// Inplace global reduction (like MPI all_reduce) while still processing AM & tasks

/// Optimizations can be added for long messages and to reduce the memory footprint
template <typename T, class opT>
void reduce_impl(T* buf, int nelem, opT op) {
void reduce(T* buf, std::size_t nelem, opT op) {
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
const std::size_t nelem_per_maxmsg = max_reducebcast_msg_size() / sizeof(T);

auto buf0 = std::unique_ptr<T[]>(new T[nelem_per_maxmsg]);
auto buf1 = std::unique_ptr<T[]>(new T[nelem_per_maxmsg]);

auto reduce_impl = [&,this](T* buf, int nelem) {
MADNESS_ASSERT(nelem <= nelem_per_maxmsg);
SafeMPI::Request req0, req1;
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
Tag gsum_tag = world_.mpi.unique_tag();

T* buf0 = new T[nelem];
T* buf1 = new T[nelem];

if (child0 != -1) req0 = world_.mpi.Irecv(buf0, nelem*sizeof(T), MPI_BYTE, child0, gsum_tag);
if (child1 != -1) req1 = world_.mpi.Irecv(buf1, nelem*sizeof(T), MPI_BYTE, child1, gsum_tag);
if (child0 != -1) req0 = world_.mpi.Irecv(buf0.get(), nelem*sizeof(T), MPI_BYTE, child0, gsum_tag);
if (child1 != -1) req1 = world_.mpi.Irecv(buf1.get(), nelem*sizeof(T), MPI_BYTE, child1, gsum_tag);

if (child0 != -1) {
World::await(req0);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf0[i]);
World::await(req0);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf0[i]);
}
if (child1 != -1) {
World::await(req1);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf1[i]);
World::await(req1);
for (long i=0; i<(long)nelem; ++i) buf[i] = op(buf[i],buf1[i]);
}

delete [] buf0;
delete [] buf1;

if (parent != -1) {
req0 = world_.mpi.Isend(buf, nelem*sizeof(T), MPI_BYTE, parent, gsum_tag);
World::await(req0);
req0 = world_.mpi.Isend(buf, nelem*sizeof(T), MPI_BYTE, parent, gsum_tag);
World::await(req0);
}

broadcast(buf, nelem, 0);
}
};

public:
/// Inplace global reduction (like MPI all_reduce) while still processing AM & tasks

/// Optimizations can be added for long messages and to reduce the memory footprint
template <typename T, class opT>
void reduce(T* buf, std::size_t nelem, opT op) {
const std::size_t nelem_per_maxmsg = max_reducebcast_msg_size() / sizeof(T);
while (nelem) {
const int n = std::min(nelem_per_maxmsg, nelem);
reduce_impl(buf, n, op);
reduce_impl(buf, n);
nelem -= n;
buf += n;
}
Expand Down Expand Up @@ -906,44 +902,94 @@ namespace madness {
/// \return on rank 0 returns the concatenated vector, elsewhere returns an empty vector
template <typename T>
std::vector<T> concat0(const std::vector<T>& v, size_t bufsz=1024*1024) {
MADNESS_ASSERT(bufsz <= std::numeric_limits<int>::max());

SafeMPI::Request req0, req1;
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
Tag gsum_tag = world_.mpi.unique_tag();

MADNESS_ASSERT(bufsz <= std::numeric_limits<int>::max());
auto buf0 = std::unique_ptr<std::byte[]>(new std::byte[bufsz]);
auto buf1 = std::unique_ptr<std::byte[]>(new std::byte[bufsz]);

const int batch_size = static_cast<int>(std::min(static_cast<size_t>(max_reducebcast_msg_size()),bufsz));
std::deque<Tag> tags; // stores tags used to send each batch

unsigned char* buf0 = new unsigned char[bufsz];
unsigned char* buf1 = new unsigned char[bufsz];
// Receives one batch from each existing child and waits for completion.
//
// \param buf_offset byte offset into buf0 (child0) and buf1 (child1) at which
//        this batch is stored; must be less than bufsz
//
// A new unique tag is drawn for every batch and appended to `tags`.
auto batched_receives = [&,this](size_t buf_offset) {
    MADNESS_ASSERT(batch_size <= bufsz);
    MADNESS_ASSERT(buf_offset < bufsz);
    Tag gsum_tag = world_.mpi.unique_tag();
    tags.push_back(gsum_tag);

    // The receive count is the capacity offered to MPI for this batch: at
    // most one batch, and never more than the space left past buf_offset.
    // (Using bufsz - batch_size here would offer 0 bytes whenever
    // batch_size == bufsz and would exceed the remaining space for later
    // batches.)
    const int max_nbytes_to_recv = static_cast<int>(
        std::min(static_cast<size_t>(batch_size), bufsz - buf_offset));

    if (child0 != -1)
        req0 = world_.mpi.Irecv(buf0.get() + buf_offset,
                                max_nbytes_to_recv, MPI_BYTE, child0,
                                gsum_tag);
    if (child1 != -1)
        req1 = world_.mpi.Irecv(buf1.get() + buf_offset,
                                max_nbytes_to_recv, MPI_BYTE, child1,
                                gsum_tag);

    if (child0 != -1) {
        World::await(req0);
    }
    if (child1 != -1) {
        World::await(req1);
    }
};

// receive data in batches
if (child0 != -1 || child1 != -1) {
size_t buf_offset = 0;
while (buf_offset < bufsz) {
batched_receives(buf_offset);
buf_offset += batch_size;
buf_offset = std::min(buf_offset, bufsz);
}
}

if (child0 != -1) req0 = world_.mpi.Irecv(buf0, bufsz, MPI_BYTE, child0, gsum_tag);
if (child1 != -1) req1 = world_.mpi.Irecv(buf1, bufsz, MPI_BYTE, child1, gsum_tag);

std::vector<T> left, right;
if (child0 != -1) {
World::await(req0);
archive::BufferInputArchive ar(buf0, bufsz);
ar & left;
archive::BufferInputArchive ar(buf0.get(), bufsz);
ar & left;
}
if (child1 != -1) {
World::await(req1);
archive::BufferInputArchive ar(buf1, bufsz);
ar & right;
for (unsigned int i=0; i<right.size(); ++i) left.push_back(right[i]);
archive::BufferInputArchive ar(buf1.get(), bufsz);
ar & right;
for (unsigned int i = 0; i < right.size(); ++i)
left.push_back(right[i]);
}

for (unsigned int i=0; i<v.size(); ++i) left.push_back(v[i]);

// send data in batches
if (parent != -1) {
archive::BufferOutputArchive ar(buf0, bufsz);
ar & left;
req0 = world_.mpi.Isend(buf0, ar.size(), MPI_BYTE, parent, gsum_tag);
archive::BufferOutputArchive ar(buf0.get(), bufsz);
ar & left;
const auto total_nbytes_to_send = ar.size();

auto batched_send = [&,this](size_t buf_offset) {
MADNESS_ASSERT(batch_size <= bufsz);
Tag gsum_tag;
if (tags.empty()) {
gsum_tag = world_.mpi.unique_tag();
} else {
gsum_tag = tags.front();
tags.pop_front();
}

const auto nbytes_to_send = static_cast<int>(std::min(static_cast<size_t>(batch_size), total_nbytes_to_send - buf_offset));
req0 = world_.mpi.Isend(buf0.get() + buf_offset, nbytes_to_send, MPI_BYTE, parent,
gsum_tag);
World::await(req0);
};

size_t buf_offset = 0;
while (buf_offset < bufsz) {
batched_send(buf_offset);
buf_offset += batch_size;
buf_offset = std::min(buf_offset, bufsz);
}
}

delete [] buf0;
delete [] buf1;

if (parent == -1) return left;
else return std::vector<T>();
}
Expand Down

0 comments on commit 87715d9

Please sign in to comment.