[PG] add support for subgroup creation
ajaypanyala committed Nov 14, 2024
1 parent 1a7c616 commit 5ca53d3
Showing 6 changed files with 109 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/tamm/atomic_counter.hpp
@@ -86,7 +86,7 @@ class AtomicCounterGA: public AtomicCounter {
#endif
pg_{pg} {
#if defined(USE_UPCXX)
- ad_i64 = new upcxx::atomic_domain<int64_t>({upcxx::atomic_op::fetch_add}, *pg.team());
+ ad_i64 = new upcxx::atomic_domain<int64_t>({upcxx::atomic_op::fetch_add}, *pg.comm());
#endif
}

@@ -110,7 +110,7 @@ class AtomicCounterGA: public AtomicCounter {
{
// upcxx::persona_scope master_scope(master_mtx,
// upcxx::master_persona());
- dobj = new upcxx::dist_object<upcxx::global_ptr<int64_t>>(local_gptr, *pg_.team());
+ dobj = new upcxx::dist_object<upcxx::global_ptr<int64_t>>(local_gptr, *pg_.comm());
}

pg_.barrier();
4 changes: 2 additions & 2 deletions src/tamm/memory_manager_ga.hpp
@@ -88,7 +88,7 @@ class MemoryManagerGA: public MemoryManager {
#if defined(USE_UPCXX)
void alloc_coll_upcxx(ElementType eltype, Size local_nelements, MemoryRegionGA* pmr, int nranks,
int64_t element_size, int64_t nels) {
- upcxx::team* team = pg_.team();
+ upcxx::team* team = pg_.comm();
pmr->gptrs_ = new upcxx::global_ptr<uint8_t>[nranks];
pmr->eltype_ = eltype;
pmr->eltype_size_ = get_element_size(eltype);
@@ -277,7 +277,7 @@ class MemoryManagerGA: public MemoryManager {
explicit MemoryManagerGA(ProcGroup pg): MemoryManager{pg, MemoryManagerKind::ga} {
EXPECTS(pg.is_valid());
#if defined(USE_UPCXX)
- team_ = pg.team();
+ team_ = pg.comm();
#else
pg_ = pg;
ga_pg_ = pg.ga_pg();
100 changes: 98 additions & 2 deletions src/tamm/proc_group.hpp
@@ -90,6 +90,90 @@ class ProcGroup {
}
#endif

/**
* @brief Create a process group with only the calling process in it
*
* @return ProcGroup New ProcGroup object that contains only the calling process
*/
static ProcGroup create_self() {
#if defined(USE_UPCXX)
upcxx::team* scomm = new upcxx::team(upcxx::world().create(std::vector<int>{upcxx::rank_me()}));
ProcGroup pg = create_coll(*scomm);
#else
ProcGroup pg = create_coll(MPI_COMM_SELF);
#endif
return pg;
}

/**
* @brief Collectively create a ProcGroup from the given parent group and a list of ranks
*
* @param parent_group Parent ProcGroup
* @param ranks list of ranks in the parent group that make up the sub-group
* @return ProcGroup New ProcGroup object representing the corresponding process sub-group
*/
static ProcGroup create_subgroup(const ProcGroup& parent_group, std::vector<int>& ranks) {
const int nranks = ranks.size();
#if defined(USE_UPCXX)
// TODO: should use ranks in list and not first nranks
const bool in_new_team = (parent_group.rank() < ranks.size());
upcxx::team* gcomm = parent_group.comm();
upcxx::team* scomm = new upcxx::team(
gcomm->split(in_new_team ? 0 : upcxx::team::color_none, parent_group.rank().value()));
ProcGroup pg = create_coll(*scomm);
#else
MPI_Comm scomm;
MPI_Group wgroup;
MPI_Group sgroup;
auto gcomm = parent_group.comm();
MPI_Comm_group(gcomm, &wgroup);
MPI_Group_incl(wgroup, nranks, ranks.data(), &sgroup);
MPI_Comm_create(gcomm, sgroup, &scomm);
MPI_Group_free(&wgroup);
MPI_Group_free(&sgroup);
ProcGroup pg;
if(scomm != MPI_COMM_NULL) {
pg = create_coll(scomm);
MPI_Comm_free(&scomm); // since we duplicate in create_coll
}
#endif
return pg;
}

// Create subgroup from first nranks of parent group
static ProcGroup create_subgroup(const ProcGroup& parent_group, int nranks) {
std::vector<int> ranks(nranks);
for(int i = 0; i < nranks; i++) ranks[i] = i;
return create_subgroup(parent_group, ranks);
}

/**
* @brief Collectively create multiple process groups from the given parent group
*
* @param parent_group Parent ProcGroup
* @param nranks size of each sub-group
* @return ProcGroup New ProcGroup object for the sub-group that the calling process belongs to
*/
static ProcGroup create_subgroups(const ProcGroup& parent_group, int nranks) {
ProcGroup pg;
#if defined(USE_UPCXX)
upcxx::team* parent_team = parent_group.comm();
int color = upcxx::rank_me() % nranks;
int key = upcxx::rank_me() / nranks;
upcxx::team scomm = parent_team->split(color, key);
pg = create_coll(scomm);
#else
int color = 0;
MPI_Comm scomm;
int pg_rank = parent_group.rank().value();
if(nranks > 1) color = pg_rank / nranks;
MPI_Comm_split(parent_group.comm(), color, pg_rank, &scomm);
pg = create_coll(scomm);
MPI_Comm_free(&scomm); // since we duplicate in create_coll
#endif
return pg;
}

/**
* @brief Check if the given process group is valid
*
@@ -122,7 +206,7 @@ class ProcGroup {
*
* @pre is_valid()
*/
- upcxx::team* team() const {
+ upcxx::team* comm() const {
EXPECTS(is_valid());
return pginfo_->team_;
}
@@ -138,6 +222,18 @@ class ProcGroup {
return pginfo_->mpi_comm_;
}

/**
* @brief Access the Fortran handle of the underlying MPI communicator
* @return the Fortran representation of the wrapped MPI communicator
*
* @pre is_valid()
*/
MPI_Fint comm_c2f() const {
EXPECTS(is_valid());
// convert the C comm handle to its Fortran equivalent
return MPI_Comm_c2f(pginfo_->mpi_comm_);
}

/**
* @brief Obtained the underlying GA process group
*
@@ -391,7 +487,7 @@ class ProcGroup {
upcxx::promise<> p;

for(int r = 0; r < nranks; ++r)
- upcxx::broadcast(rbuf + r * rcount, scount, r, *team(), upcxx::operation_cx::as_promise(p));
+ upcxx::broadcast(rbuf + r * rcount, scount, r, *comm(), upcxx::operation_cx::as_promise(p));

p.finalize().wait();
}
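For orientation, here is a minimal usage sketch of the new subgroup helpers (not part of this commit). It assumes an MPI (non-UPC++) build, that MPI is already initialized, that `ProcGroup` lives in the `tamm` namespace, that the `create_coll(MPI_Comm)` factory used in the diff is publicly callable, and that group teardown is handled elsewhere; `subgroup_demo` and its variables are illustrative names only.

```cpp
#include <mpi.h>
#include <vector>
#include "tamm/proc_group.hpp"

// Hypothetical demo of the subgroup-creation API added by this commit.
void subgroup_demo() {
  using tamm::ProcGroup;

  // Group spanning all ranks; create_coll duplicates the communicator internally.
  ProcGroup world = ProcGroup::create_coll(MPI_COMM_WORLD);

  // Group containing only the calling process (wraps MPI_COMM_SELF).
  ProcGroup self = ProcGroup::create_self();

  // Sub-group built collectively from an explicit list of parent-group ranks;
  // ranks omitted from the list get back a default (invalid) ProcGroup.
  std::vector<int> ranks{0, 1};
  ProcGroup pair = ProcGroup::create_subgroup(world, ranks);

  // Convenience overload: sub-group made of the first 4 ranks of the parent.
  ProcGroup first4 = ProcGroup::create_subgroup(world, 4);

  // Partition the parent into disjoint sub-groups of 4 ranks each; every rank
  // receives the sub-group it was placed in.
  ProcGroup part = ProcGroup::create_subgroups(world, 4);

  if(pair.is_valid()) {
    // Fortran handle of the wrapped communicator, e.g. for Fortran libraries.
    MPI_Fint fcomm = pair.comm_c2f();
    (void) fcomm;
  }
  (void) self; (void) first4; (void) part;
}
```

Note that all three creation routines are collective over the parent group, so every rank of the parent must call them even if it will not be a member of the resulting sub-group.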
2 changes: 1 addition & 1 deletion src/tamm/tamm_io.hpp
@@ -241,7 +241,7 @@ void write_to_disk(Tensor<TensorType> tensor, const std::string& filename, bool
auto [nagg, ppn, subranks] = get_subgroup_info(gec, tensor, nagg_hint);
#if defined(USE_UPCXX)
upcxx::team* io_comm = new upcxx::team(
- gec.pg().team()->split(gec.pg().rank() < subranks ? 0 : upcxx::team::color_none, 0));
+ gec.pg().comm()->split(gec.pg().rank() < subranks ? 0 : upcxx::team::color_none, 0));
#else
MPI_Comm io_comm;
subcomm_from_subranks(gec, subranks, io_comm);
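The UPC++ branch above uses the usual team-split idiom: ranks that should join the I/O group pass color 0 and everyone else passes `upcxx::team::color_none`, so only the first `subranks` ranks become members. A stripped-down sketch of just that idiom follows; `make_io_team` and its parameters are illustrative, not TAMM APIs.

```cpp
#include <upcxx/upcxx.hpp>

// Build a team containing the first `subranks` ranks of `parent`.
// Members pass color 0; the remaining ranks pass color_none and do not join.
upcxx::team* make_io_team(upcxx::team& parent, int subranks) {
  const bool member = parent.rank_me() < subranks;
  // Heap-allocate so the team outlives the enclosing scope, as the code above does.
  return new upcxx::team(parent.split(member ? 0 : upcxx::team::color_none, /*key=*/0));
}
```

The MPI branch keeps using the existing `subcomm_from_subranks` helper for the same purpose.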
8 changes: 4 additions & 4 deletions src/tamm/tamm_utils.hpp
@@ -774,7 +774,7 @@ TensorType linf_norm(LabeledTensor<TensorType> ltensor) {
auto [nagg, ppn, subranks] = get_subgroup_info(gec, tensor);
#if defined(USE_UPCXX)
upcxx::team* sub_comm =
- new upcxx::team(gec.pg().team()->split(rank < subranks ? 0 : upcxx::team::color_none, 0));
+ new upcxx::team(gec.pg().comm()->split(rank < subranks ? 0 : upcxx::team::color_none, 0));
#else
MPI_Comm sub_comm;
subcomm_from_subranks(gec, subranks, sub_comm);
@@ -842,7 +842,7 @@ void apply_ewise_ip(LabeledTensor<TensorType> ltensor, std::function<TensorType(
auto [nagg, ppn, subranks] = get_subgroup_info(gec, tensor);
#if defined(USE_UPCXX)
upcxx::team* sub_comm = new upcxx::team(
- gec.pg().team()->split(gec.pg().rank() < subranks ? 0 : upcxx::team::color_none, 0));
+ gec.pg().comm()->split(gec.pg().rank() < subranks ? 0 : upcxx::team::color_none, 0));
#else
MPI_Comm sub_comm;
subcomm_from_subranks(gec, subranks, sub_comm);
@@ -1079,7 +1079,7 @@ TensorType sum(LabeledTensor<TensorType> ltensor) {
auto [nagg, ppn, subranks] = get_subgroup_info(gec, tensor);
#if defined(USE_UPCXX)
upcxx::team* sub_comm = new upcxx::team(
- gec.pg().team()->split(gec.pg().rank() < subranks ? 0 : upcxx::team::color_none, 0));
+ gec.pg().comm()->split(gec.pg().rank() < subranks ? 0 : upcxx::team::color_none, 0));
#else
MPI_Comm sub_comm;
subcomm_from_subranks(gec, subranks, sub_comm);
@@ -1177,7 +1177,7 @@ TensorType norm(ExecutionContext& gec, LabeledTensor<TensorType> ltensor) {
auto [nagg, ppn, subranks] = get_subgroup_info(gec, tensor);
#if defined(USE_UPCXX)
upcxx::team* sub_comm =
- new upcxx::team(gec.pg().team()->split(rank < subranks ? 0 : upcxx::team::color_none, 0));
+ new upcxx::team(gec.pg().comm()->split(rank < subranks ? 0 : upcxx::team::color_none, 0));
#else
MPI_Comm sub_comm;
subcomm_from_subranks(gec, subranks, sub_comm);
4 changes: 2 additions & 2 deletions src/tamm/tensor_impl.hpp
@@ -1039,7 +1039,7 @@ class DenseTensorImpl: public TensorImpl<T> {
gptrs_.resize(nranks);
upcxx::promise<> p(nranks);
for(int r = 0; r < nranks; r++)
- upcxx::broadcast(local_gptr_, r, *ec->pg().team())
+ upcxx::broadcast(local_gptr_, r, *ec->pg().comm())
.then([this, &p, r](upcxx::global_ptr<uint8_t> result) {
gptrs_[r] = result;
p.fulfill_anonymous(1);
@@ -1158,7 +1158,7 @@ class DenseTensorImpl: public TensorImpl<T> {
TensorTile t = find_tile(lo[0], lo[1], lo[2], lo[3]);

upcxx::rpc(
- *ec_->pg().team(), t.rank,
+ *ec_->pg().comm(), t.rank,
[](const upcxx::global_ptr<T>& dst_buf, const upcxx::view<T>& src_buf) {
T* dst = dst_buf.local();
size_t n = src_buf.size();
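The broadcast loop above follows a common UPC++ pointer-exchange idiom: every rank broadcasts its `global_ptr`, all ranks collect the full table, and a promise tracks completion. Below is a self-contained sketch of that idiom under the assumption of a `uint8_t` element type; `exchange_gptrs` and its parameters are illustrative, not TAMM code.

```cpp
#include <upcxx/upcxx.hpp>

#include <cstdint>
#include <vector>

// Collect one global_ptr from every rank of `tm`; `mine` is this rank's pointer.
std::vector<upcxx::global_ptr<std::uint8_t>>
exchange_gptrs(upcxx::team& tm, upcxx::global_ptr<std::uint8_t> mine) {
  const int n = tm.rank_n();
  std::vector<upcxx::global_ptr<std::uint8_t>> table(n);
  upcxx::promise<> p(n); // one anonymous dependency per pending broadcast
  for(int r = 0; r < n; ++r)
    upcxx::broadcast(mine, r, tm).then(
      [&table, &p, r](upcxx::global_ptr<std::uint8_t> g) {
        table[r] = g;
        p.fulfill_anonymous(1);
      });
  p.get_future().wait(); // drives progress until every broadcast has completed
  return table;
}
```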
