Skip to content

Commit

Permalink
FFT: Add new domain decomposition strategy (AMReX-Codes#4221)
Browse files Browse the repository at this point in the history
Instead of always using pencil decomposition, the FFT now has the option of
using slab decomposition. This allows the x and y directions to be processed
together without MPI communication.
  • Loading branch information
WeiqunZhang authored Nov 9, 2024
1 parent 8e7bb00 commit 57380f3
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 111 deletions.
1 change: 1 addition & 0 deletions .github/workflows/apps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ jobs:
runs-on: ubuntu-latest
needs: check_changes
if: needs.check_changes.outputs.has_non_docs_changes == 'true'
steps:
- uses: actions/checkout@v4
- name: Checkout pyamrex
uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/dependencies/dependencies_hip.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ echo 'export PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/opt/rocm/profiler/bin:/opt/r

# we should not need to export HIP_PATH=/opt/rocm/hip with those installs

sudo apt-get clean
sudo apt-get update

# Ref.: https://rocmdocs.amd.com/en/latest/Installation_Guide/Installation-Guide.html#installing-development-packages-for-cross-compilation
Expand Down
74 changes: 55 additions & 19 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include <algorithm>
#include <complex>
#include <limits>
#include <memory>
#include <utility>
#include <variant>
Expand All @@ -43,6 +44,8 @@ namespace amrex::FFT

enum struct Direction { forward, backward, both, none };

//! Strategy for decomposing the domain. With slab, the x and y directions
//! can be processed together without MPI communication; pencil is the
//! other option.
enum struct DomainStrategy { slab, pencil };

AMREX_ENUM( Boundary, periodic, even, odd );

enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,
Expand All @@ -55,7 +58,11 @@ struct Info
//! batch size.
bool batch_mode = false;

//! Max number of processes to use
int nprocs = std::numeric_limits<int>::max();

//! Store the batch-mode flag and return this object so that setter
//! calls can be chained.
Info& setBatchMode (bool x)
{
    batch_mode = x;
    return *this;
}
//! Store the maximum number of processes to use and return this object
//! so that setter calls can be chained.
Info& setNumProcs (int n)
{
    nprocs = n;
    return *this;
}
};

#ifdef AMREX_USE_HIP
Expand Down Expand Up @@ -172,18 +179,34 @@ struct Plan
}

template <Direction D>
void init_r2c (Box const& box, T* pr, VendorComplex* pc)
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
{
static_assert(D == Direction::forward || D == Direction::backward);

int rank = is_2d_transform ? 2 : 1;

kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
defined = true;
pf = (void*)pr;
pb = (void*)pc;

n = box.length(0);
int nc = (n/2) + 1;
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
int len[2] = {};
if (rank == 1) {
len[0] = box.length(0);
len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`.
} else {
len[0] = box.length(1); // Most FFT libraries assume row-major ordering
len[1] = box.length(0); // except for rocfft
}
int nr = (rank == 1) ? len[0] : len[0]*len[1];
n = nr;
int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
#if (AMREX_SPACEDIM == 1)
howmany = 1;
#else
howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
: AMREX_D_TERM(1, *1 , *box.length(2));
#endif

amrex::ignore_unused(nc);

Expand All @@ -193,43 +216,52 @@ struct Plan
if constexpr (D == Direction::forward) {
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, nc, fwd_type, howmany, &work_size));
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
} else {
cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, nc, nullptr, 1, n, bwd_type, howmany, &work_size));
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
}
#elif defined(AMREX_USE_HIP)

auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
const std::size_t length = n;
// switch to column-major ordering
std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
if constexpr (D == Direction::forward) {
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&plan, rocfft_placement_notinplace,
rocfft_transform_type_real_forward, prec, 1,
&length, howmany, nullptr));
rocfft_transform_type_real_forward, prec, rank,
length, howmany, nullptr));
} else {
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&plan, rocfft_placement_notinplace,
rocfft_transform_type_real_inverse, prec, 1,
&length, howmany, nullptr));
rocfft_transform_type_real_inverse, prec, rank,
length, howmany, nullptr));
}

#elif defined(AMREX_USE_SYCL)

auto* pp = new mkl_desc_r(n);
mkl_desc_r* pp;
if (rank == 1) {
pp = new mkl_desc_r(len[0]);
} else {
pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
}
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif
pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
std::vector<std::int64_t> strides = {0,1};
std::vector<std::int64_t> strides;
strides.push_back(0);
if (rank == 2) { strides.push_back(len[1]); }
strides.push_back(1);
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
Expand All @@ -247,21 +279,21 @@ struct Plan
if constexpr (std::is_same_v<float,T>) {
if constexpr (D == Direction::forward) {
plan = fftwf_plan_many_dft_r2c
(1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc,
(rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
} else {
plan = fftwf_plan_many_dft_c2r
(1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n,
(rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
}
} else {
if constexpr (D == Direction::forward) {
plan = fftw_plan_many_dft_r2c
(1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc,
(rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
} else {
plan = fftw_plan_many_dft_c2r
(1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n,
(rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
}
}
Expand Down Expand Up @@ -1087,13 +1119,17 @@ namespace detail
template <typename FA1, typename FA2>
std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
{
bool not_same_fa = true;
if constexpr (std::is_same_v<FA1,FA2>) {
not_same_fa = (&fa1 != &fa2);
}
using FAB1 = typename FA1::FABType::value_type;
using FAB2 = typename FA2::FABType::value_type;
using T1 = typename FAB1::value_type;
using T2 = typename FAB2::value_type;
auto myproc = ParallelContext::MyProcSub();
bool alloc_1 = (myproc < fa1.size());
bool alloc_2 = (myproc < fa2.size());
bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
void* p = nullptr;
if (alloc_1 && alloc_2) {
Box const& box1 = fa1.fabbox(myproc);
Expand Down
134 changes: 99 additions & 35 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public:
using MF = typename R2C<T>::MF;
using cMF = typename R2C<T>::cMF;

explicit OpenBCSolver (Box const& domain);
explicit OpenBCSolver (Box const& domain, Info const& info = Info{});

template <class F>
void setGreensFunction (F const& greens_function);
Expand All @@ -25,34 +25,70 @@ public:
[[nodiscard]] Box const& Domain () const { return m_domain; }

private:
static Box make_grown_domain (Box const& domain, Info const& info);

Box m_domain;
Info m_info;
R2C<T> m_r2c;
cMF m_G_fft;
std::unique_ptr<R2C<T>> m_r2c_green;
};

template <typename T>
OpenBCSolver<T>::OpenBCSolver (Box const& domain)
// Return a box with the same lower corner and index type as `domain`, whose
// upper corner is moved out by the domain length in each direction. In 3D
// builds with info.batch_mode set, the z-direction is not extended.
Box OpenBCSolver<T>::make_grown_domain (Box const& domain, Info const& info)
{
    // Per-direction amount added to the upper corner.
    IntVect growth = domain.length();
#if (AMREX_SPACEDIM == 3)
    if (info.batch_mode) {
        growth[2] = 0; // leave z at its original extent
    }
#else
    amrex::ignore_unused(info);
#endif
    IntVect const grown_hi = domain.bigEnd() + growth;
    return Box(domain.smallEnd(), grown_hi, domain.ixType());
}

// Construct the solver for `domain` with the options in `info`.
// m_r2c is built on the box returned by make_grown_domain(domain, info).
// m_G_fft is then set up from spectral data: in 3D batch mode from a
// separate R2C object (m_r2c_green), otherwise from m_r2c itself.
template <typename T>
OpenBCSolver<T>::OpenBCSolver (Box const& domain, Info const& info)
: m_domain(domain),
// NOTE(review): the next line looks like the pre-change m_r2c initializer
// that this diff replaces with the two initializers after it -- confirm.
m_r2c(Box(domain.smallEnd(), domain.bigEnd()+domain.length(), domain.ixType()))
m_info(info),
m_r2c(OpenBCSolver<T>::make_grown_domain(domain,info), info)
{
// NOTE(review): the next three lines look like the pre-change body; the
// same statements reappear in the final (non-batch) branch -- confirm.
auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0);
#if (AMREX_SPACEDIM == 3)
if (m_info.batch_mode) {
// Batch mode (3D builds only): build a second R2C, m_r2c_green, on a copy
// of the grown domain whose z-range is reset to [0, nprocs-1].
auto gdom = make_grown_domain(domain,m_info);
gdom.enclosedCells(2);
gdom.setSmall(2, 0);
// nprocs is the smallest of ParallelContext::NProcsSub(), the limit in
// m_info.nprocs, and the length of m_domain in direction 2.
int nprocs = std::min({ParallelContext::NProcsSub(),
m_info.nprocs,
m_domain.length(2)});
gdom.setBig(2, nprocs-1);
m_r2c_green = std::make_unique<R2C<T>>(gdom,info);
auto [sd, ord] = m_r2c_green->getSpectralData();
// m_G_fft is constructed with amrex::make_alias from m_r2c_green's
// spectral data instead of being defined separately.
m_G_fft = cMF(*sd, amrex::make_alias, 0, 1);
} else
#endif
{
// Non-batch path: m_r2c_green is not used, and m_G_fft is defined on the
// BoxArray and DistributionMap of m_r2c's spectral data.
amrex::ignore_unused(m_r2c_green);
auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0);
}
}

template <typename T>
template <class F>
void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
{
auto* infab = detail::get_fab(m_r2c.m_rx);
auto* infab = m_info.batch_mode ? detail::get_fab(m_r2c_green->m_rx)
: detail::get_fab(m_r2c.m_rx);
auto const& lo = m_domain.smallEnd();
auto const& lo3 = lo.dim3();
auto const& len = m_domain.length3d();
if (infab) {
auto const& a = infab->array();
auto box = infab->box();
GpuArray<int,3> nimages{1,1,1};
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
int ndims = m_info.batch_mode ? AMREX_SPACEDIM : AMREX_SPACEDIM-1;
for (int idim = 0; idim < ndims; ++idim) {
if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
nimages[idim] = 2;
Expand All @@ -62,46 +98,59 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
box.shift(-lo);
amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
{
T G;
if (i == len[0] || j == len[1] || k == len[2]) {
a(i+lo3.x,j+lo3.y,k+lo3.z) = T(0);
G = 0;
} else {
auto ii = i;
auto jj = (j > len[1]) ? 2*len[1]-j : j;
auto kk = (k > len[2]) ? 2*len[2]-k : k;
auto G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
for (int koff = 0; koff < nimages[2]; ++koff) {
int k2 = (koff == 0) ? k : 2*len[2]-k;
if (k2 == 2*len[2]) { continue; }
for (int joff = 0; joff < nimages[1]; ++joff) {
int j2 = (joff == 0) ? j : 2*len[1]-j;
if (j2 == 2*len[1]) { continue; }
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
int i2 = (ioff == 0) ? i : 2*len[0]-i;
if (i2 == 2*len[0]) { continue; }
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
}
for (int koff = 0; koff < nimages[2]; ++koff) {
int k2 = (koff == 0) ? k : 2*len[2]-k;
if ((k2 == 2*len[2]) || (koff == 1 && k == len[2])) {
continue;
}
for (int joff = 0; joff < nimages[1]; ++joff) {
int j2 = (joff == 0) ? j : 2*len[1]-j;
if ((j2 == 2*len[1]) || (joff == 1 && j == len[1])) {
continue;
}
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
int i2 = (ioff == 0) ? i : 2*len[0]-i;
if ((i2 == 2*len[0]) || (ioff == 1 && i == len[0])) {
continue;
}
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
}
}
}
});
}

m_r2c.forward(m_r2c.m_rx);
if (m_info.batch_mode) {
m_r2c_green->forward(m_r2c_green->m_rx);
} else {
m_r2c.forward(m_r2c.m_rx);
}

auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
auto const* srcfab = detail::get_fab(*sd);
if (srcfab) {
auto* dstfab = detail::get_fab(m_G_fft);
if (dstfab) {
if (!m_info.batch_mode) {
auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
auto const* srcfab = detail::get_fab(*sd);
if (srcfab) {
auto* dstfab = detail::get_fab(m_G_fft);
if (dstfab) {
#if defined(AMREX_USE_GPU)
Gpu::dtod_memcpy_async
Gpu::dtod_memcpy_async
#else
std::memcpy
std::memcpy
#endif
(dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes());
} else {
amrex::Abort("FFT::OpenBCSolver: how did this happen");
(dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes());
} else {
amrex::Abort("FFT::OpenBCSolver: how did this happen");
}
}
}
}
Expand All @@ -115,7 +164,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)

m_r2c.forward(inmf);

auto scaling_factor = T(1) / T(m_r2c.m_real_domain.numPts());
auto scaling_factor = m_r2c.scalingFactor();

auto const* gfab = detail::get_fab(m_G_fft);
if (gfab) {
Expand All @@ -125,9 +174,24 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
if (rhofab) {
auto* pdst = rhofab->dataPtr();
auto const* psrc = gfab->dataPtr();
amrex::ParallelFor(rhofab->box().numPts(), [=] AMREX_GPU_DEVICE (Long i)
Box const& rhobox = rhofab->box();
#if (AMREX_SPACEDIM == 3)
Long leng = gfab->box().numPts();
if (m_info.batch_mode) {
AMREX_ASSERT(gfab->box().length(2) == 1 &&
leng == (rhobox.length(0) * rhobox.length(1)));
} else {
AMREX_ASSERT(leng == rhobox.numPts());
}
#endif
amrex::ParallelFor(rhobox.numPts(), [=] AMREX_GPU_DEVICE (Long i)
{
pdst[i] *= psrc[i] * scaling_factor;
#if (AMREX_SPACEDIM == 3)
Long isrc = i % leng;
#else
Long isrc = i;
#endif
pdst[i] *= psrc[isrc] * scaling_factor;
});
} else {
amrex::Abort("FFT::OpenBCSolver::solve: how did this happen?");
Expand Down
Loading

0 comments on commit 57380f3

Please sign in to comment.