Skip to content

Commit

Permalink
Merge pull request #25 from danclaudino/hpc_virt_update
Browse files Browse the repository at this point in the history
Added shots to HPC virtualization
  • Loading branch information
danclaudino authored Jul 25, 2024
2 parents 94fcd56 + ec7b050 commit 6bcd980
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 97 deletions.
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ add_compile_flags_if_supported(-Wno-maybe-uninitialized)
# Check MPI status
# if MPI_CXX_COMPILER is not empty and XACC_ENABLE_MPI is set
# turn MPI_ENABLED on
if(NOT MPI_CXX_COMPILER STREQUAL "" AND XACC_ENABLE_MPI)
# Update: we don't really need to give the path to the compiler
# because if MPI is found, MPI_CXX_COMPILER is populated
if(XACC_ENABLE_MPI)
find_package(MPI)

if(MPI_FOUND)
Expand Down
19 changes: 15 additions & 4 deletions quantum/plugins/decorators/hpc-virtualization/MPIProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,25 @@ Copyright (C) 2018-2021 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2021 Oak Ridge National Laboratory (UT-Battelle) **/

#include "MPIProxy.hpp"

#include "mpi.h"

#include <cstdlib>

#include <iostream>
#include <algorithm>

// Explicit mappings from C++ element types to MPI datatype constants.
// Only these three types are supported; the primary template declares
// getMPIDatatype() without a generic definition, so any other T fails
// at link time rather than silently misbehaving.
template <>
MPI_Datatype MPIDataTypeResolver<int>::getMPIDatatype() { return MPI_INT; }

template <>
MPI_Datatype MPIDataTypeResolver<double>::getMPIDatatype() { return MPI_DOUBLE; }

template <>
MPI_Datatype MPIDataTypeResolver<char>::getMPIDatatype() { return MPI_CHAR; }

namespace xacc {

//Temporary buffers:
Expand Down
50 changes: 50 additions & 0 deletions quantum/plugins/decorators/hpc-virtualization/MPIProxy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ Copyright (C) 2018-2021 Oak Ridge National Laboratory (UT-Battelle) **/
#include <vector>
#include <memory>
#include <cassert>
#include "mpi.h"

// Maps a C++ element type T to its MPI datatype constant
// (e.g. int -> MPI_INT). Definitions exist only as explicit
// specializations (int, double, char in MPIProxy.cpp); using any
// other T compiles but fails to link.
template <typename T>
class MPIDataTypeResolver {
public:
  // Returns the MPI_Datatype corresponding to T.
  MPI_Datatype getMPIDatatype();
};

namespace xacc {

Expand Down Expand Up @@ -142,6 +149,49 @@ class ProcessGroup {
different MPI processes, thus putting them into disjoint subgroups. **/
std::shared_ptr<ProcessGroup> split(int my_subgroup) const;


// some useful wrappers

// I could move this to a single function, but don't
// want to abuse template specialization here
// this broadcasts a single element (int/char/double)
// Broadcasts a single element (int/char/double) from rank 0 to every
// rank of this process group.
// @param element On rank 0: the value to send. On all other ranks:
//        overwritten with rank 0's value.
// NOTE(fix): the parameter is now taken by reference. The previous
// by-value signature let MPI_Bcast write into a local copy, so
// non-root ranks discarded the received value and the broadcast was
// a no-op for receivers.
template<typename T>
void broadcast(T &element) {

  MPIDataTypeResolver<T> resolver;
  MPI_Datatype mpiType = resolver.getMPIDatatype();
  // Root is fixed at rank 0 of this group's communicator.
  MPI_Bcast(&element, 1, mpiType, 0,
            this->getMPICommProxy().getRef<MPI_Comm>());
}

// Broadcasts the contents of a vector from rank 0 to every rank of
// this process group.
// @param vec On rank 0: the data to send. On other ranks: receive
//        buffer, overwritten in place. All ranks must pass a vector
//        already sized to the broadcast length (MPI_Bcast requires
//        matching counts; this wrapper does not resize).
template<typename T>
void broadcast(std::vector<T> &vec) {

  MPIDataTypeResolver<T> resolver;
  MPI_Datatype mpiType = resolver.getMPIDatatype();
  // MPI counts are int: make the size_t -> int narrowing explicit.
  MPI_Bcast(vec.data(), static_cast<int>(vec.size()), mpiType, 0,
            this->getMPICommProxy().getRef<MPI_Comm>());
}


// All-gathers the contents of each rank's local vector into a global
// vector present on every rank (MPI_Allgatherv).
// @param local      this rank's contribution
// @param global     output buffer; caller must pre-size it to the
//                   total element count across all ranks
// @param nLocalData element count contributed by each rank, indexed
//                   by rank
// @param shift      displacement into `global` (in elements) at which
//                   each rank's data is placed
template<typename T>
void allGatherv(std::vector<T> &local,
                std::vector<T> &global,
                std::vector<int> &nLocalData,
                std::vector<int> &shift) {

  MPIDataTypeResolver<T> resolver;
  MPI_Datatype mpiType = resolver.getMPIDatatype();
  // MPI send counts are int: make the size_t -> int narrowing explicit.
  MPI_Allgatherv(local.data(), static_cast<int>(local.size()), mpiType,
                 global.data(), nLocalData.data(),
                 shift.data(), mpiType,
                 this->getMPICommProxy().getRef<MPI_Comm>());
}

protected:

std::vector<unsigned int> process_ranks_; //global ranks of the MPI processes forming the process group
Expand Down
Loading

0 comments on commit 6bcd980

Please sign in to comment.