Skip to content

Commit

Permalink
[PG] allow parent group specification internally
Browse files Browse the repository at this point in the history
  • Loading branch information
ajaypanyala committed Nov 15, 2024
1 parent 5ca53d3 commit d005809
Showing 1 changed file with 55 additions and 12 deletions.
67 changes: 55 additions & 12 deletions src/tamm/proc_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ class ProcGroup {
}

/**
* @brief Collectively create a ProcGroup from the given communicator
*
* @brief Collectively create a ProcGroup from the given communicator.
* Assumes parent is GA world group.
* @param mpi_comm Communicator to be used as the basis of the ProcGroup
* @return ProcGroup New ProcGroup object that duplicates @param mpi_comm and
* creates the corresponding GA process group
Expand All @@ -88,6 +88,18 @@ class ProcGroup {
pg.pginfo_->is_valid_ = (mpi_comm != MPI_COMM_NULL);
return pg;
}

static ProcGroup create_coll(const ProcGroup& parent_group, MPI_Comm mpi_comm) {
  // Work on a duplicate of the caller's communicator rather than the handle itself.
  MPI_Comm dup_comm;
  MPI_Comm_dup(mpi_comm, &dup_comm);

  ProcGroup pg;
  auto&     info = *pg.pginfo_;

  // Record the duplicate and flag it as created by this group.
  info.mpi_comm_         = dup_comm;
  info.created_mpi_comm_ = true;

  // The GA process group is derived from the original communicator and the
  // supplied parent group (not from the duplicate).
  info.ga_pg_         = create_ga_process_group_coll(parent_group, mpi_comm);
  info.created_ga_pg_ = true;

  // Validity reflects whether the caller passed a non-null communicator.
  // NOTE(review): MPI_Comm_dup above is still invoked for MPI_COMM_NULL — confirm
  // callers never pass a null communicator here.
  info.is_valid_ = (mpi_comm != MPI_COMM_NULL);
  return pg;
}
#endif

/**
Expand All @@ -97,10 +109,10 @@ class ProcGroup {
*/
static ProcGroup create_self() {
// NOTE(review): each preprocessor branch below contains two alternative
// declarations of `pg`. This looks like removed and added diff lines shown
// together — confirm which pair is current before relying on this text.
#if defined(USE_UPCXX)
// Alternative A: build a team from world() containing only rank_me(), then
// hand it to create_coll.
upcxx::team* scomm = new upcxx::team(upcxx::world().create(std::vector<int>{upcxx::rank_me()}));
ProcGroup pg = create_coll(*scomm);
// Alternative B: split local_team() using rank_me() as the color, then pass
// the resulting team pointer directly to the ProcGroup constructor.
upcxx::team* team_self = new upcxx::team(upcxx::local_team().split(upcxx::rank_me(), 0));
ProcGroup pg{team_self};
#else
// Alternative A: go through create_coll with MPI_COMM_SELF.
ProcGroup pg = create_coll(MPI_COMM_SELF);
// Alternative B: construct directly from MPI_COMM_SELF and the value returned
// by ProcGroup::self_ga_pgroup().
ProcGroup pg{MPI_COMM_SELF, ProcGroup::self_ga_pgroup()};
#endif
return pg;
}
Expand Down Expand Up @@ -129,13 +141,13 @@ class ProcGroup {
MPI_Comm_group(gcomm, &wgroup);
MPI_Group_incl(wgroup, nranks, ranks.data(), &sgroup);
MPI_Comm_create(gcomm, sgroup, &scomm);
MPI_Group_free(&wgroup);
MPI_Group_free(&sgroup);
ProcGroup pg;
if(scomm != MPI_COMM_NULL) {
pg = create_coll(scomm);
pg = create_coll(parent_group, scomm);
MPI_Comm_free(&scomm); // since we duplicate in create_coll
}
MPI_Group_free(&wgroup);
MPI_Group_free(&sgroup);
#endif
return pg;
}
Expand Down Expand Up @@ -168,7 +180,7 @@ class ProcGroup {
int pg_rank = parent_group.rank().value();
if(nranks > 1) color = pg_rank / nranks;
MPI_Comm_split(parent_group.comm(), color, pg_rank, &scomm);
pg = create_coll(scomm);
pg = create_coll(parent_group, scomm);
MPI_Comm_free(&scomm); // since we duplicate in create_coll
#endif
return pg;
Expand Down Expand Up @@ -271,12 +283,13 @@ class ProcGroup {
}

/**
* Collectively clone the given process group
* Collectively clone the given process group. Not used currently.
* @return A copy of this process group
*/
#if defined(USE_UPCXX)
// UPC++ build: forward this group's team to create_coll.
ProcGroup clone_coll() const {
  auto& own_team = *pginfo_->team_;
  return create_coll(own_team);
}
#else
// TODO: handle cloning of subgroups
// MPI build: forward this group's communicator to the single-argument
// create_coll overload.
ProcGroup clone_coll() const {
  MPI_Comm own_comm = pginfo_->mpi_comm_;
  return create_coll(own_comm);
}
#endif

Expand Down Expand Up @@ -735,9 +748,10 @@ class ProcGroup {

private:
/**
* Create a GA process group corresponding to the given proc group
* Create a GA process group corresponding to the given proc group.
* Assumes parent group is GA world group.
* @param comm MPI communicator whose ranks make up the GA process group
* @return GA processes group on this TAMM process group
* @return GA process group on this TAMM process group
*/
static int create_ga_process_group_coll(MPI_Comm comm) {
int nranks;
Expand All @@ -746,6 +760,7 @@ class ProcGroup {
int ranks[nranks], ranks_world[nranks];
MPI_Comm_group(comm, &group);

// also works when GA is initialized with an existing MPI communicator
MPI_Comm_group(GA_MPI_Comm(), &group_world);

for(int i = 0; i < nranks; i++) { ranks[i] = i; }
Expand All @@ -760,6 +775,34 @@ class ProcGroup {
return ga_pg;
}

#if !defined(USE_UPCXX)
/**
* Create a GA process group corresponding to the given proc group and parent group.
* @param pg TAMM process group
* @return GA process group on this TAMM process group
*/
static int create_ga_process_group_coll(const ProcGroup& parent_group, MPI_Comm comm) {
int nranks;
MPI_Comm_size(comm, &nranks);
MPI_Group group, group_world;
int ranks[nranks], ranks_world[nranks];
MPI_Comm_group(comm, &group);

MPI_Comm_group(parent_group.comm(), &group_world);

for(int i = 0; i < nranks; i++) { ranks[i] = i; }
MPI_Group_translate_ranks(group, nranks, ranks, group_world, ranks_world);

int ga_pg_default = GA_Pgroup_get_default();
GA_Pgroup_set_default(parent_group.ga_pg());
int ga_pg = GA_Pgroup_create(ranks_world, nranks);
GA_Pgroup_set_default(ga_pg_default);
MPI_Group_free(&group);
MPI_Group_free(&group_world);
return ga_pg;
}
#endif

/**
* @brief Swap contents of two given objects
*
Expand Down

0 comments on commit d005809

Please sign in to comment.