Minor MPI use tidying
rupertnash committed Jun 5, 2024
1 parent de0c2fd commit b4c1ba7
Showing 6 changed files with 68 additions and 71 deletions.
37 changes: 19 additions & 18 deletions Code/geometry/Domain.cc
@@ -91,7 +91,7 @@ namespace hemelb::geometry
// GmyReadResult. Since that is only guaranteed to know
// the rank of sites which are assigned to this rank, it
// may well return UNKNOWN_PROCESS
auto read_result_rank_for_site = [this, &max_site_index, &readResult](
auto read_result_rank_for_site = [&max_site_index, &readResult](
util::Vector3D<site_t> const& global_site_coords,
Vec16& block_coords, site_t& site_gmy_idx
) {
@@ -469,35 +469,36 @@ namespace hemelb::geometry
// propagate to different partitions is avoided (only their values
// will be communicated). It's here!
// Allocate the request variable.
net::Net tempNet(comms);
std::vector<net::MpiRequest> reqs(neighbouringProcs.size());
int i_req = 0;
for (auto& neighbouringProc : neighbouringProcs)
{
// We know that the elements are contiguous from asserts
auto flatten_vec_of_pairs = [] (auto&& vop) {
auto& [loc, d] = vop.front();
return std::span{&loc[0], vop.size() * 4};
};

// One way send receive. The lower numbered netTop->ProcessorCount send and the higher numbered ones receive.
// It seems that, for each pair of processors, the lower numbered one ends up with its own
// edge sites and directions stored and the higher numbered one ends up with those on the
// other processor.
if (neighbouringProc.Rank > localRank)
{
tempNet.RequestSendV<site_t>(
flatten_vec_of_pairs(sharedFLocationForEachProc.at(neighbouringProc.Rank)),
neighbouringProc.Rank);
}
else
{
// We know that the elements are contiguous from asserts
// in Domain.h about size and alignment of point_direction.
// Using a template as we want the const/mutable variants.
auto const& to_send = sharedFLocationForEachProc.at(neighbouringProc.Rank);

reqs[i_req] = comms.Issend(
std::span<site_t const>(&to_send[0].first[0], to_send.size() * 4),
neighbouringProc.Rank
);
} else {
auto& dest = sharedFLocationForEachProc[neighbouringProc.Rank];
dest.resize(neighbouringProc.SharedDistributionCount);
tempNet.RequestReceiveV(flatten_vec_of_pairs(dest),
neighbouringProc.Rank);
reqs[i_req] = comms.Irecv(
std::span<site_t>(&dest[0].first[0], 4*dest.size()),
neighbouringProc.Rank
);
}
i_req += 1;
}

tempNet.Dispatch();
net::MpiRequest::Waitall(reqs);
}

void Domain::InitialiseReceiveLookup(proc2neighdata const& sharedFLocationForEachProc)
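The new Domain.cc code above swaps the net::Net request machinery for plain non-blocking point-to-point calls: for each pair of neighbouring ranks, the lower-ranked side sends its flattened (location, direction) data and the higher-ranked side receives it, and both then wait on the collected requests. The sketch below shows the same pattern in raw MPI on a flat int64_t buffer; ExchangeWithNeighbours, the neighbour list and the shared map are illustrative stand-ins, not HemeLB's actual bookkeeping.

// Sketch only: pairwise exchange of flat int64_t buffers between
// neighbouring ranks, mirroring the Issend/Irecv/Waitall pattern above.
// "neighbours" and "shared" are hypothetical stand-ins.
#include <mpi.h>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

void ExchangeWithNeighbours(MPI_Comm comm, int myRank,
                            const std::vector<int>& neighbours,
                            std::map<int, std::vector<std::int64_t>>& shared)
{
    std::vector<MPI_Request> reqs(neighbours.size());
    constexpr int tag = 0;

    for (std::size_t i = 0; i < neighbours.size(); ++i) {
        int nbr = neighbours[i];
        auto& buf = shared[nbr];  // assumed pre-sized on the receiving side
        if (nbr > myRank) {
            // Lower-ranked side of the pair sends its edge data...
            MPI_Issend(buf.data(), static_cast<int>(buf.size()), MPI_INT64_T,
                       nbr, tag, comm, &reqs[i]);
        } else {
            // ...and the higher-ranked side receives the partner's copy.
            MPI_Irecv(buf.data(), static_cast<int>(buf.size()), MPI_INT64_T,
                      nbr, tag, comm, &reqs[i]);
        }
    }
    // Complete every request before the buffers are reused.
    MPI_Waitall(static_cast<int>(reqs.size()), reqs.data(), MPI_STATUSES_IGNORE);
}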
51 changes: 10 additions & 41 deletions Code/lb/iolets/BoundaryComms.cc
@@ -20,48 +20,25 @@ namespace hemelb::lb

// Only BC proc sends
if (bcComm.IsCurrentProcTheBCProc())
{
sendRequest = new MPI_Request[nProcs];
sendStatus = new MPI_Status[nProcs];
}
else
{
sendRequest = nullptr;
sendStatus = nullptr;
}
}

BoundaryComms::~BoundaryComms()
{

if (bcComm.IsCurrentProcTheBCProc())
{
delete[] sendRequest;
delete[] sendStatus;
}
sendRequest.resize(nProcs);
}

void BoundaryComms::Wait()
{
if (hasBoundary)
{
HEMELB_MPI_CALL(MPI_Wait, (&receiveRequest, &receiveStatus));
}
receiveRequest.Wait();
}

void BoundaryComms::WaitAllComms()
{
// Now wait for all to complete
if (bcComm.IsCurrentProcTheBCProc())
{
HEMELB_MPI_CALL(MPI_Waitall, (nProcs, sendRequest, sendStatus));
if (bcComm.IsCurrentProcTheBCProc()) {
net::MpiRequest::Waitall(sendRequest);

if (hasBoundary)
HEMELB_MPI_CALL(MPI_Wait, (&receiveRequest, &receiveStatus));
}
else
{
HEMELB_MPI_CALL(MPI_Wait, (&receiveRequest, &receiveStatus));
receiveRequest.Wait();
} else {
receiveRequest.Wait();
}

}
@@ -70,29 +47,21 @@ namespace hemelb::lb
void BoundaryComms::Send(distribn_t* density)
{
for (int proc = 0; proc < nProcs; proc++)
{
HEMELB_MPI_CALL(MPI_Isend,
( density, 1, net::MpiDataType(*density), procsList[proc], 100, bcComm, &sendRequest[proc] ));
}
sendRequest[proc] = bcComm.Issend(*density, procsList[proc], BC_TAG);
}

void BoundaryComms::Receive(distribn_t* density)
{
if (hasBoundary)
{
HEMELB_MPI_CALL(MPI_Irecv,
( density, 1, net::MpiDataType(*density), bcComm.GetBCProcRank(), 100, bcComm, &receiveRequest ));
}
receiveRequest = bcComm.Irecv(*density, bcComm.GetBCProcRank(), BC_TAG);
}

void BoundaryComms::FinishSend()
{
// Don't move on to next step with BC proc until all messages have been sent
// Precautionary measure to make sure proc doesn't overwrite, before message is sent
if (bcComm.IsCurrentProcTheBCProc())
{
HEMELB_MPI_CALL(MPI_Waitall, (nProcs, sendRequest, sendStatus));
}
net::MpiRequest::Waitall(sendRequest);
}

}
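BoundaryComms now follows a simple fan-out: the boundary-condition rank posts one Issend of the density per rank in procsList (completed later in FinishSend/WaitAllComms), while every rank with a boundary posts a single matching receive tagged BC_TAG. Below is a raw-MPI sketch of that shape; DistributeDensity, kTag and targets are illustrative, the target list is assumed not to include the sending rank itself, and a blocking receive stands in for the real code's Irecv/Wait pair.

// Sketch only: one designated rank sends a scalar to every rank in
// "targets" and waits on all sends; the other ranks receive it.
#include <mpi.h>
#include <cstddef>
#include <vector>

constexpr int kTag = 100;  // mirrors the BC_TAG constant introduced above

void DistributeDensity(MPI_Comm comm, int bcRank, double density,
                       const std::vector<int>& targets)
{
    int rank;
    MPI_Comm_rank(comm, &rank);

    if (rank == bcRank) {
        std::vector<MPI_Request> sends(targets.size());
        for (std::size_t i = 0; i < targets.size(); ++i)
            MPI_Issend(&density, 1, MPI_DOUBLE, targets[i], kTag, comm, &sends[i]);
        // Don't reuse the density buffer until every send has completed.
        MPI_Waitall(static_cast<int>(sends.size()), sends.data(), MPI_STATUSES_IGNORE);
    } else {
        double received;
        MPI_Recv(&received, 1, MPI_DOUBLE, bcRank, kTag, comm, MPI_STATUS_IGNORE);
    }
}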
14 changes: 6 additions & 8 deletions Code/lb/iolets/BoundaryComms.h
@@ -16,10 +16,9 @@ namespace hemelb::lb

class BoundaryComms
{
public:
public:
BoundaryComms(SimulationState* iSimState, std::vector<int> &iProcsList,
const BoundaryCommunicator& boundaryComm, bool iHasBoundary);
~BoundaryComms();

void Wait();

@@ -36,7 +35,9 @@
void WaitAllComms();
void FinishSend();

private:
private:
// MPI tag for communication
static constexpr int BC_TAG = 100;
// This is necessary to support BC proc having fluid sites
bool hasBoundary;

@@ -46,11 +47,8 @@
std::vector<int> procsList;
const BoundaryCommunicator& bcComm;

MPI_Request *sendRequest;
MPI_Status *sendStatus;

MPI_Request receiveRequest;
MPI_Status receiveStatus;
std::vector<net::MpiRequest> sendRequest;
net::MpiRequest receiveRequest;

SimulationState* mState;
};
11 changes: 11 additions & 0 deletions Code/net/MpiCommunicator.cc
@@ -11,6 +11,17 @@

namespace hemelb::net
{
void MpiRequest::Waitall(std::span<MpiRequest> reqs) {
// These asserts check that there's no extra data in an
// MpiRequest object, so we can effectively just pointer
// alias a span of them.
static_assert(sizeof(MpiRequest) == sizeof(MPI_Request));
static_assert(alignof(MpiRequest)== alignof(MPI_Request));
int N = std::ssize(reqs);
MPI_Request* data = reqs.data() ? &reqs.data()->req : nullptr;
MpiCall{MPI_Waitall}(N, data, MPI_STATUSES_IGNORE);
}

namespace {
void Deleter(MPI_Comm* comm)
{
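The static_asserts in Waitall above are what make the pointer aliasing safe in practice: an MpiRequest holds nothing but an MPI_Request, so a contiguous span of wrappers can be handed straight to MPI_Waitall. Here is a minimal, self-contained sketch of that technique, with a hypothetical Request wrapper standing in for net::MpiRequest.

// Sketch of the layout-compatibility trick used by Waitall above.
#include <mpi.h>
#include <span>

struct Request {
    MPI_Request req = MPI_REQUEST_NULL;
};

void WaitallSketch(std::span<Request> reqs)
{
    // Only valid because the wrapper adds no extra data and no padding.
    static_assert(sizeof(Request) == sizeof(MPI_Request));
    static_assert(alignof(Request) == alignof(MPI_Request));
    int n = static_cast<int>(reqs.size());
    MPI_Request* first = reqs.empty() ? nullptr : &reqs.front().req;
    MPI_Waitall(n, first, MPI_STATUSES_IGNORE);
}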
8 changes: 8 additions & 0 deletions Code/net/MpiCommunicator.h
@@ -103,13 +103,16 @@ namespace hemelb::net
inline void Wait() {
HEMELB_MPI_CALL(MPI_Wait, (&req, MPI_STATUS_IGNORE));
}
static void Waitall(std::span<MpiRequest> reqs);

[[nodiscard]] inline bool Test() {
int done = 0;
HEMELB_MPI_CALL(MPI_Test, (&req, &done, MPI_STATUS_IGNORE));
return done;
}
};


// Holds an MPI communicator and exposes communication functions
// via members. It will own the underlying MPI_Comm (i.e. it will
// call MPI_Comm_free) if created from another MpiCommunicator
@@ -281,6 +284,11 @@
void Receive(std::vector<T>& val, int src, int tag = 0,
MPI_Status* stat = MPI_STATUS_IGNORE) const;

template <typename T>
[[nodiscard]] MpiRequest Irecv(std::span<T> dest, int src, int tag = 0) const;
template <typename T>
[[nodiscard]] MpiRequest Irecv(T& dest, int src, int tag = 0) const;

//! \brief Create a distributed graph communicator assuming unweighted and bidirectional communication.
[[nodiscard]] MpiCommunicator DistGraphAdjacent(std::vector<int> my_neighbours, bool reorder = true) const;

18 changes: 14 additions & 4 deletions Code/net/MpiCommunicator.hpp
@@ -235,15 +235,13 @@ namespace hemelb::net
template <typename T>
MpiRequest MpiCommunicator::Issend(std::span<T const> vals, int dest, int tag) const {
MpiRequest ans;
HEMELB_MPI_CALL(MPI_Issend,
(vals.data(), vals.size(), MpiDataType<T>(), dest, tag, *this, &ans.req));
MpiCall{MPI_Issend}(vals.data(), vals.size(), MpiDataType<T>(), dest, tag, *commPtr, &ans.req);
return ans;
}
template <typename T>
MpiRequest MpiCommunicator::Issend(T const& val, int dest, int tag) const {
MpiRequest ans;
HEMELB_MPI_CALL(MPI_Issend,
(&val, 1, MpiDataType<T>(), dest, tag, *this, &ans.req));
MpiCall{MPI_Issend}(&val, 1, MpiDataType<T>(), dest, tag, *commPtr, &ans.req);
return ans;
}

@@ -258,6 +256,18 @@
HEMELB_MPI_CALL(MPI_Recv, (vals.data(), vals.size(), MpiDataType<T>(), src, tag, *this, stat));
}

template <typename T>
MpiRequest MpiCommunicator::Irecv(std::span<T> dest, int src, int tag) const {
MpiRequest ans;
MpiCall{MPI_Irecv}(dest.data(), dest.size(), MpiDataType<T>(), src, tag, *commPtr, &ans.req);
return ans;
}
template <typename T>
MpiRequest MpiCommunicator::Irecv(T &dest, int src, int tag) const {
MpiRequest ans;
MpiCall{MPI_Irecv}(&dest, 1, MpiDataType<T>(), src, tag, *commPtr, &ans.req);
return ans;
}
}

#endif // HEMELB_NET_MPICOMMUNICATOR_HPP
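Taken together, the additions to MpiCommunicator.h/.hpp give a small non-blocking API: Issend and Irecv return an MpiRequest, and MpiRequest::Waitall completes a batch of them. The usage sketch below is written against those interfaces; the include path, the PingNeighbours function and the fixed payload are assumptions for illustration, the neighbour list is assumed not to contain the calling rank, and it would need a HemeLB build environment to compile.

// Usage sketch for the new non-blocking helpers declared above.
#include <cstddef>
#include <vector>
#include "net/MpiCommunicator.h"

void PingNeighbours(const hemelb::net::MpiCommunicator& comm,
                    const std::vector<int>& neighbours)
{
    std::vector<hemelb::net::MpiRequest> reqs;
    reqs.reserve(2 * neighbours.size());
    // Send and receive buffers must stay alive until Waitall returns.
    std::vector<int> outbox(neighbours.size(), 42);
    std::vector<int> inbox(neighbours.size());

    for (std::size_t i = 0; i < neighbours.size(); ++i) {
        // Each call returns an MpiRequest; collect them...
        reqs.push_back(comm.Issend(outbox[i], neighbours[i]));
        reqs.push_back(comm.Irecv(inbox[i], neighbours[i]));
    }
    // ...and complete them all in one call.
    hemelb::net::MpiRequest::Waitall(reqs);
}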
