From 57380f37423ce4efcdc95c24cb577520f87fac24 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Fri, 8 Nov 2024 16:37:59 -0800 Subject: [PATCH] FFT: Add new domain decomposition strategy (#4221) Instead of pencil, it has the option of doing slab decomposition. This allows the x and y directions to be processed together without MPI communication. --- .github/workflows/apps.yml | 1 + .../dependencies/dependencies_hip.sh | 1 + Src/FFT/AMReX_FFT_Helper.H | 74 +++++++--- Src/FFT/AMReX_FFT_OpenBCSolver.H | 134 ++++++++++++----- Src/FFT/AMReX_FFT_Poisson.H | 41 ++++-- Src/FFT/AMReX_FFT_R2C.H | 137 +++++++++++++----- Src/FFT/AMReX_FFT_R2X.H | 14 +- Tests/FFT/R2C/main.cpp | 6 +- 8 files changed, 297 insertions(+), 111 deletions(-) diff --git a/.github/workflows/apps.yml b/.github/workflows/apps.yml index c8968b24f7..042b0d7c9d 100644 --- a/.github/workflows/apps.yml +++ b/.github/workflows/apps.yml @@ -114,6 +114,7 @@ jobs: runs-on: ubuntu-latest needs: check_changes if: needs.check_changes.outputs.has_non_docs_changes == 'true' + steps: - uses: actions/checkout@v4 - name: Checkout pyamrex uses: actions/checkout@v4 diff --git a/.github/workflows/dependencies/dependencies_hip.sh b/.github/workflows/dependencies/dependencies_hip.sh index df4f274ef3..6b69c5433a 100755 --- a/.github/workflows/dependencies/dependencies_hip.sh +++ b/.github/workflows/dependencies/dependencies_hip.sh @@ -40,6 +40,7 @@ echo 'export PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/opt/rocm/profiler/bin:/opt/r # we should not need to export HIP_PATH=/opt/rocm/hip with those installs +sudo apt-get clean sudo apt-get update # Ref.: https://rocmdocs.amd.com/en/latest/Installation_Guide/Installation-Guide.html#installing-development-packages-for-cross-compilation diff --git a/Src/FFT/AMReX_FFT_Helper.H b/Src/FFT/AMReX_FFT_Helper.H index be3c76ea4c..315e0641ac 100644 --- a/Src/FFT/AMReX_FFT_Helper.H +++ b/Src/FFT/AMReX_FFT_Helper.H @@ -34,6 +34,7 @@ #include #include +#include #include #include #include @@ -43,6 +44,8 @@ namespace amrex::FFT enum struct Direction { forward, backward, both, none }; +enum struct DomainStrategy { slab, pencil }; + AMREX_ENUM( Boundary, periodic, even, odd ); enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b, @@ -55,7 +58,11 @@ struct Info //! batch size. bool batch_mode = false; + //! Max number of processes to use + int nprocs = std::numeric_limits::max(); + Info& setBatchMode (bool x) { batch_mode = x; return *this; } + Info& setNumProcs (int n) { nprocs = n; return *this; } }; #ifdef AMREX_USE_HIP @@ -172,18 +179,34 @@ struct Plan } template - void init_r2c (Box const& box, T* pr, VendorComplex* pc) + void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false) { static_assert(D == Direction::forward || D == Direction::backward); + int rank = is_2d_transform ? 2 : 1; + kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b; defined = true; pf = (void*)pr; pb = (void*)pc; - n = box.length(0); - int nc = (n/2) + 1; - howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2)); + int len[2] = {}; + if (rank == 1) { + len[0] = box.length(0); + len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`. + } else { + len[0] = box.length(1); // Most FFT libraries assume row-major ordering + len[1] = box.length(0); // except for rocfft + } + int nr = (rank == 1) ? len[0] : len[0]*len[1]; + n = nr; + int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0]; +#if (AMREX_SPACEDIM == 1) + howmany = 1; +#else + howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2)) + : AMREX_D_TERM(1, *1 , *box.length(2)); +#endif amrex::ignore_unused(nc); @@ -193,33 +216,39 @@ struct Plan if constexpr (D == Direction::forward) { cufftType fwd_type = std::is_same_v ? CUFFT_R2C : CUFFT_D2Z; AMREX_CUFFT_SAFE_CALL - (cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, nc, fwd_type, howmany, &work_size)); + (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size)); AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); } else { cufftType bwd_type = std::is_same_v ? CUFFT_C2R : CUFFT_Z2D; AMREX_CUFFT_SAFE_CALL - (cufftMakePlanMany(plan, 1, &n, nullptr, 1, nc, nullptr, 1, n, bwd_type, howmany, &work_size)); + (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size)); AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); } #elif defined(AMREX_USE_HIP) auto prec = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; - const std::size_t length = n; + // switch to column-major ordering + std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])}; if constexpr (D == Direction::forward) { AMREX_ROCFFT_SAFE_CALL (rocfft_plan_create(&plan, rocfft_placement_notinplace, - rocfft_transform_type_real_forward, prec, 1, - &length, howmany, nullptr)); + rocfft_transform_type_real_forward, prec, rank, + length, howmany, nullptr)); } else { AMREX_ROCFFT_SAFE_CALL (rocfft_plan_create(&plan, rocfft_placement_notinplace, - rocfft_transform_type_real_inverse, prec, 1, - &length, howmany, nullptr)); + rocfft_transform_type_real_inverse, prec, rank, + length, howmany, nullptr)); } #elif defined(AMREX_USE_SYCL) - auto* pp = new mkl_desc_r(n); + mkl_desc_r* pp; + if (rank == 1) { + pp = new mkl_desc_r(len[0]); + } else { + pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])}); + } #ifndef AMREX_USE_MKL_DFTI_2024 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); @@ -227,9 +256,12 @@ struct Plan pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE); #endif pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany); - pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n); + pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr); pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc); - std::vector strides = {0,1}; + std::vector strides; + strides.push_back(0); + if (rank == 2) { strides.push_back(len[1]); } + strides.push_back(1); #ifndef AMREX_USE_MKL_DFTI_2024 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides); pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides); @@ -247,21 +279,21 @@ struct Plan if constexpr (std::is_same_v) { if constexpr (D == Direction::forward) { plan = fftwf_plan_many_dft_r2c - (1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc, + (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } else { plan = fftwf_plan_many_dft_c2r - (1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n, + (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } } else { if constexpr (D == Direction::forward) { plan = fftw_plan_many_dft_r2c - (1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc, + (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } else { plan = fftw_plan_many_dft_c2r - (1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n, + (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } } @@ -1087,13 +1119,17 @@ namespace detail template std::unique_ptr make_mfs_share (FA1& fa1, FA2& fa2) { + bool not_same_fa = true; + if constexpr (std::is_same_v) { + not_same_fa = (&fa1 != &fa2); + } using FAB1 = typename FA1::FABType::value_type; using FAB2 = typename FA2::FABType::value_type; using T1 = typename FAB1::value_type; using T2 = typename FAB2::value_type; auto myproc = ParallelContext::MyProcSub(); bool alloc_1 = (myproc < fa1.size()); - bool alloc_2 = (myproc < fa2.size()); + bool alloc_2 = (myproc < fa2.size()) && not_same_fa; void* p = nullptr; if (alloc_1 && alloc_2) { Box const& box1 = fa1.fabbox(myproc); diff --git a/Src/FFT/AMReX_FFT_OpenBCSolver.H b/Src/FFT/AMReX_FFT_OpenBCSolver.H index 42dccad89d..1f75d18719 100644 --- a/Src/FFT/AMReX_FFT_OpenBCSolver.H +++ b/Src/FFT/AMReX_FFT_OpenBCSolver.H @@ -15,7 +15,7 @@ public: using MF = typename R2C::MF; using cMF = typename R2C::cMF; - explicit OpenBCSolver (Box const& domain); + explicit OpenBCSolver (Box const& domain, Info const& info = Info{}); template void setGreensFunction (F const& greens_function); @@ -25,26 +25,61 @@ public: [[nodiscard]] Box const& Domain () const { return m_domain; } private: + static Box make_grown_domain (Box const& domain, Info const& info); + Box m_domain; + Info m_info; R2C m_r2c; cMF m_G_fft; + std::unique_ptr> m_r2c_green; }; template -OpenBCSolver::OpenBCSolver (Box const& domain) +Box OpenBCSolver::make_grown_domain (Box const& domain, Info const& info) +{ + IntVect len = domain.length(); +#if (AMREX_SPACEDIM == 3) + if (info.batch_mode) { len[2] = 0; } +#else + amrex::ignore_unused(info); +#endif + return Box(domain.smallEnd(), domain.bigEnd()+len, domain.ixType()); +} + +template +OpenBCSolver::OpenBCSolver (Box const& domain, Info const& info) : m_domain(domain), - m_r2c(Box(domain.smallEnd(), domain.bigEnd()+domain.length(), domain.ixType())) + m_info(info), + m_r2c(OpenBCSolver::make_grown_domain(domain,info), info) { - auto [sd, ord] = m_r2c.getSpectralData(); - amrex::ignore_unused(ord); - m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0); +#if (AMREX_SPACEDIM == 3) + if (m_info.batch_mode) { + auto gdom = make_grown_domain(domain,m_info); + gdom.enclosedCells(2); + gdom.setSmall(2, 0); + int nprocs = std::min({ParallelContext::NProcsSub(), + m_info.nprocs, + m_domain.length(2)}); + gdom.setBig(2, nprocs-1); + m_r2c_green = std::make_unique>(gdom,info); + auto [sd, ord] = m_r2c_green->getSpectralData(); + m_G_fft = cMF(*sd, amrex::make_alias, 0, 1); + } else +#endif + { + amrex::ignore_unused(m_r2c_green); + auto [sd, ord] = m_r2c.getSpectralData(); + amrex::ignore_unused(ord); + m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0); + } } template template void OpenBCSolver::setGreensFunction (F const& greens_function) { - auto* infab = detail::get_fab(m_r2c.m_rx); + auto* infab = m_info.batch_mode ? detail::get_fab(m_r2c_green->m_rx) + : detail::get_fab(m_r2c.m_rx); auto const& lo = m_domain.smallEnd(); auto const& lo3 = lo.dim3(); auto const& len = m_domain.length3d(); @@ -52,7 +87,8 @@ void OpenBCSolver::setGreensFunction (F const& greens_function) auto const& a = infab->array(); auto box = infab->box(); GpuArray nimages{1,1,1}; - for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { + int ndims = m_info.batch_mode ? AMREX_SPACEDIM : AMREX_SPACEDIM-1; + for (int idim = 0; idim < ndims; ++idim) { if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) { box.growHi(idim, -len[idim]+1); // +1 to include the middle plane nimages[idim] = 2; @@ -62,46 +98,59 @@ void OpenBCSolver::setGreensFunction (F const& greens_function) box.shift(-lo); amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k) { + T G; if (i == len[0] || j == len[1] || k == len[2]) { - a(i+lo3.x,j+lo3.y,k+lo3.z) = T(0); + G = 0; } else { auto ii = i; auto jj = (j > len[1]) ? 2*len[1]-j : j; auto kk = (k > len[2]) ? 2*len[2]-k : k; - auto G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z); - for (int koff = 0; koff < nimages[2]; ++koff) { - int k2 = (koff == 0) ? k : 2*len[2]-k; - if (k2 == 2*len[2]) { continue; } - for (int joff = 0; joff < nimages[1]; ++joff) { - int j2 = (joff == 0) ? j : 2*len[1]-j; - if (j2 == 2*len[1]) { continue; } - for (int ioff = 0; ioff < nimages[0]; ++ioff) { - int i2 = (ioff == 0) ? i : 2*len[0]-i; - if (i2 == 2*len[0]) { continue; } - a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G; + G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z); + } + for (int koff = 0; koff < nimages[2]; ++koff) { + int k2 = (koff == 0) ? k : 2*len[2]-k; + if ((k2 == 2*len[2]) || (koff == 1 && k == len[2])) { + continue; + } + for (int joff = 0; joff < nimages[1]; ++joff) { + int j2 = (joff == 0) ? j : 2*len[1]-j; + if ((j2 == 2*len[1]) || (joff == 1 && j == len[1])) { + continue; + } + for (int ioff = 0; ioff < nimages[0]; ++ioff) { + int i2 = (ioff == 0) ? i : 2*len[0]-i; + if ((i2 == 2*len[0]) || (ioff == 1 && i == len[0])) { + continue; } + a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G; } } } }); } - m_r2c.forward(m_r2c.m_rx); + if (m_info.batch_mode) { + m_r2c_green->forward(m_r2c_green->m_rx); + } else { + m_r2c.forward(m_r2c.m_rx); + } - auto [sd, ord] = m_r2c.getSpectralData(); - amrex::ignore_unused(ord); - auto const* srcfab = detail::get_fab(*sd); - if (srcfab) { - auto* dstfab = detail::get_fab(m_G_fft); - if (dstfab) { + if (!m_info.batch_mode) { + auto [sd, ord] = m_r2c.getSpectralData(); + amrex::ignore_unused(ord); + auto const* srcfab = detail::get_fab(*sd); + if (srcfab) { + auto* dstfab = detail::get_fab(m_G_fft); + if (dstfab) { #if defined(AMREX_USE_GPU) - Gpu::dtod_memcpy_async + Gpu::dtod_memcpy_async #else - std::memcpy + std::memcpy #endif - (dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes()); - } else { - amrex::Abort("FFT::OpenBCSolver: how did this happen"); + (dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes()); + } else { + amrex::Abort("FFT::OpenBCSolver: how did this happen"); + } } } } @@ -115,7 +164,7 @@ void OpenBCSolver::solve (MF& phi, MF const& rho) m_r2c.forward(inmf); - auto scaling_factor = T(1) / T(m_r2c.m_real_domain.numPts()); + auto scaling_factor = m_r2c.scalingFactor(); auto const* gfab = detail::get_fab(m_G_fft); if (gfab) { @@ -125,9 +174,24 @@ void OpenBCSolver::solve (MF& phi, MF const& rho) if (rhofab) { auto* pdst = rhofab->dataPtr(); auto const* psrc = gfab->dataPtr(); - amrex::ParallelFor(rhofab->box().numPts(), [=] AMREX_GPU_DEVICE (Long i) + Box const& rhobox = rhofab->box(); +#if (AMREX_SPACEDIM == 3) + Long leng = gfab->box().numPts(); + if (m_info.batch_mode) { + AMREX_ASSERT(gfab->box().length(2) == 1 && + leng == (rhobox.length(0) * rhobox.length(1))); + } else { + AMREX_ASSERT(leng == rhobox.numPts()); + } +#endif + amrex::ParallelFor(rhobox.numPts(), [=] AMREX_GPU_DEVICE (Long i) { - pdst[i] *= psrc[i] * scaling_factor; +#if (AMREX_SPACEDIM == 3) + Long isrc = i % leng; +#else + Long isrc = i; +#endif + pdst[i] *= psrc[isrc] * scaling_factor; }); } else { amrex::Abort("FFT::OpenBCSolver::solve: how did this happen?"); diff --git a/Src/FFT/AMReX_FFT_Poisson.H b/Src/FFT/AMReX_FFT_Poisson.H index c50f26f66e..8ab467cc54 100644 --- a/Src/FFT/AMReX_FFT_Poisson.H +++ b/Src/FFT/AMReX_FFT_Poisson.H @@ -19,18 +19,33 @@ public: template ,int> = 0> Poisson (Geometry const& geom, Array,AMREX_SPACEDIM> const& bc) - : m_geom(geom), m_bc(bc), m_r2x(geom.Domain(),bc) - {} + : m_geom(geom), m_bc(bc) + { + bool all_periodic = true; + for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { + all_periodic = all_periodic + && (bc[idim].first == Boundary::periodic) + && (bc[idim].second == Boundary::periodic); + } + if (all_periodic) { + m_r2c = std::make_unique>(m_geom.Domain()); + } else { + m_r2x = std::make_unique> (m_geom.Domain(), m_bc); + } + } template ,int> = 0> explicit Poisson (Geometry const& geom) : m_geom(geom), m_bc{AMREX_D_DECL(std::make_pair(Boundary::periodic,Boundary::periodic), std::make_pair(Boundary::periodic,Boundary::periodic), - std::make_pair(Boundary::periodic,Boundary::periodic))}, - m_r2x(geom.Domain(),m_bc) + std::make_pair(Boundary::periodic,Boundary::periodic))} { - AMREX_ALWAYS_ASSERT(m_geom.isAllPeriodic()); + if (m_geom.isAllPeriodic()) { + m_r2c = std::make_unique>(m_geom.Domain()); + } else { + amrex::Abort("FFT::Poisson: wrong BC"); + } } void solve (MF& soln, MF const& rhs); @@ -38,7 +53,8 @@ public: private: Geometry m_geom; Array,AMREX_SPACEDIM> m_bc; - R2X m_r2x; + std::unique_ptr> m_r2x; + std::unique_ptr> m_r2c; }; #if (AMREX_SPACEDIM == 3) @@ -114,7 +130,7 @@ void Poisson::solve (MF& soln, MF const& rhs) {AMREX_D_DECL(T(2)/T(m_geom.CellSize(0)*m_geom.CellSize(0)), T(2)/T(m_geom.CellSize(1)*m_geom.CellSize(1)), T(2)/T(m_geom.CellSize(2)*m_geom.CellSize(2)))}; - auto scale = m_r2x.scalingFactor(); + auto scale = (m_r2x) ? m_r2x->scalingFactor() : m_r2c->scalingFactor(); GpuArray offset{AMREX_D_DECL(T(0),T(0),T(0))}; // Not sure about odd-even and even-odd yet @@ -133,8 +149,7 @@ void Poisson::solve (MF& soln, MF const& rhs) } } - m_r2x.forwardThenBackward(rhs, soln, - [=] AMREX_GPU_DEVICE (int i, int j, int k, auto& spectral_data) + auto f = [=] AMREX_GPU_DEVICE (int i, int j, int k, auto& spectral_data) { amrex::ignore_unused(j,k); AMREX_D_TERM(T a = fac[0]*(i+offset[0]);, @@ -147,7 +162,13 @@ void Poisson::solve (MF& soln, MF const& rhs) spectral_data /= k2; } spectral_data *= scale; - }); + }; + + if (m_r2x) { + m_r2x->forwardThenBackward(rhs, soln, f); + } else { + m_r2c->forwardThenBackward(rhs, soln, f); + } } #if (AMREX_SPACEDIM == 3) diff --git a/Src/FFT/AMReX_FFT_R2C.H b/Src/FFT/AMReX_FFT_R2C.H index 5969c6a789..aaa5fac4c3 100644 --- a/Src/FFT/AMReX_FFT_R2C.H +++ b/Src/FFT/AMReX_FFT_R2C.H @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -26,7 +27,9 @@ template class OpenBCSolver; * For more details, we refer the users to * https://amrex-codes.github.io/amrex/docs_html/FFT_Chapter.html. */ -template +template + // Don't change the default. Otherwise OpenBCSolver might break. class R2C { public: @@ -130,6 +133,11 @@ public: DIR == Direction::both, int> = 0> void backward (cMF const& inmf, MF& outmf); + //! Scaling factor. If the data goes through forward and then backward, + //! the result multiplied by the scaling factor is equal to the original + //! data. + [[nodiscard]] T scalingFactor () const; + /** * \brief Get the internal spectral data * @@ -176,18 +184,22 @@ private: std::unique_ptr m_cmd_y2x; // (y,x,z) -> (x,y,z) std::unique_ptr m_cmd_y2z; // (y,x,z) -> (z,x,y) std::unique_ptr m_cmd_z2y; // (z,x,y) -> (y,x,z) + std::unique_ptr m_cmd_x2z; // (x,y,z) -> (z,x,y) + std::unique_ptr m_cmd_z2x; // (z,x,y) -> (x,y,z) Swap01 m_dtos_x2y{}; Swap01 m_dtos_y2x{}; Swap02 m_dtos_y2z{}; Swap02 m_dtos_z2y{}; + RotateFwd m_dtos_x2z{}; + RotateBwd m_dtos_z2x{}; MF m_rx; cMF m_cx; cMF m_cy; cMF m_cz; - std::unique_ptr m_data_rx_cy; - std::unique_ptr m_data_cx_cz; + std::unique_ptr m_data_1; + std::unique_ptr m_data_2; Box m_real_domain; Box m_spectral_domain_x; @@ -195,10 +207,12 @@ private: Box m_spectral_domain_z; Info m_info; + + bool m_slab_decomp = false; }; -template -R2C::R2C (Box const& domain, Info const& info) +template +R2C::R2C (Box const& domain, Info const& info) : m_real_domain(domain), m_spectral_domain_x(IntVect(0), IntVect(AMREX_D_DECL(domain.length(0)/2, domain.length(1)-1, @@ -230,14 +244,20 @@ R2C::R2C (Box const& domain, Info const& info) #endif int myproc = ParallelContext::MyProcSub(); - int nprocs = ParallelContext::NProcsSub(); + int nprocs = std::min(ParallelContext::NProcsSub(), m_info.nprocs); + +#if (AMREX_SPACEDIM == 3) + if (S == DomainStrategy::slab && (m_real_domain.length(1) > 1)) { + m_slab_decomp = true; + } +#endif // // make data containers // auto bax = amrex::decompose(m_real_domain, nprocs, - {AMREX_D_DECL(false,true,true)}, true); + {AMREX_D_DECL(false,!m_slab_decomp,true)}, true); DistributionMapping dmx = detail::make_iota_distromap(bax.size()); m_rx.define(bax, dmx, 1, 0, MFInfo().SetAlloc(false)); @@ -253,7 +273,7 @@ R2C::R2C (Box const& domain, Info const& info) #if (AMREX_SPACEDIM >= 2) DistributionMapping cdmy; - if (m_real_domain.length(1) > 1) { + if ((m_real_domain.length(1) > 1) && !m_slab_decomp) { auto cbay = amrex::decompose(m_spectral_domain_y, nprocs, {AMREX_D_DECL(false,true,true)}, true); if (cbay.size() == dmx.size()) { @@ -283,8 +303,13 @@ R2C::R2C (Box const& domain, Info const& info) } #endif - m_data_rx_cy = detail::make_mfs_share(m_rx, m_cy); - m_data_cx_cz = detail::make_mfs_share(m_cx, m_cz); + if (m_slab_decomp) { + m_data_1 = detail::make_mfs_share(m_rx, m_cz); + m_data_2 = detail::make_mfs_share(m_cx, m_cx); + } else { + m_data_1 = detail::make_mfs_share(m_rx, m_cy); + m_data_2 = detail::make_mfs_share(m_cx, m_cz); + } // // make copiers @@ -300,12 +325,20 @@ R2C::R2C (Box const& domain, Info const& info) } #endif #if (AMREX_SPACEDIM == 3) - if (! m_cz.empty()) { - // comm meta-data between y and z phases - m_cmd_y2z = std::make_unique - (m_cz, m_spectral_domain_z, m_cy, IntVect(0), m_dtos_y2z); - m_cmd_z2y = std::make_unique - (m_cy, m_spectral_domain_y, m_cz, IntVect(0), m_dtos_z2y); + if (! m_cz.empty() ) { + if (m_slab_decomp) { + // comm meta-data between xy and z phases + m_cmd_x2z = std::make_unique + (m_cz, m_spectral_domain_z, m_cx, IntVect(0), m_dtos_x2z); + m_cmd_z2x = std::make_unique + (m_cx, m_spectral_domain_x, m_cz, IntVect(0), m_dtos_z2x); + } else { + // comm meta-data between y and z phases + m_cmd_y2z = std::make_unique + (m_cz, m_spectral_domain_z, m_cy, IntVect(0), m_dtos_y2z); + m_cmd_z2y = std::make_unique + (m_cy, m_spectral_domain_y, m_cz, IntVect(0), m_dtos_z2y); + } } #endif @@ -319,14 +352,14 @@ R2C::R2C (Box const& domain, Info const& info) auto* pr = m_rx[myproc].dataPtr(); auto* pc = (typename Plan::VendorComplex *)m_cx[myproc].dataPtr(); #ifdef AMREX_USE_SYCL - m_fft_fwd_x.template init_r2c(box, pr, pc); + m_fft_fwd_x.template init_r2c(box, pr, pc, m_slab_decomp); m_fft_bwd_x = m_fft_fwd_x; #else if constexpr (D == Direction::both || D == Direction::forward) { - m_fft_fwd_x.template init_r2c(box, pr, pc); + m_fft_fwd_x.template init_r2c(box, pr, pc, m_slab_decomp); } if constexpr (D == Direction::both || D == Direction::backward) { - m_fft_bwd_x.template init_r2c(box, pr, pc); + m_fft_bwd_x.template init_r2c(box, pr, pc, m_slab_decomp); } #endif } @@ -343,8 +376,8 @@ R2C::R2C (Box const& domain, Info const& info) #endif } -template -R2C::~R2C () +template +R2C::~R2C () { if (m_fft_bwd_x.plan != m_fft_fwd_x.plan) { m_fft_bwd_x.destroy(); @@ -360,10 +393,10 @@ R2C::~R2C () m_fft_fwd_z.destroy(); } -template +template template > -void R2C::forward (MF const& inmf) +void R2C::forward (MF const& inmf) { BL_PROFILE("FFT::R2C::forward(in)"); @@ -380,18 +413,23 @@ void R2C::forward (MF const& inmf) if ( m_cmd_y2z) { ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z); } +#if (AMREX_SPACEDIM == 3) + else if ( m_cmd_x2z) { + ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, 1, m_dtos_x2z); + } +#endif m_fft_fwd_z.template compute_c2c(); } -template +template template > -void R2C::backward (MF& outmf) +void R2C::backward (MF& outmf) { backward_doit(outmf); } -template -void R2C::backward_doit (MF& outmf, IntVect const& ngout) +template +void R2C::backward_doit (MF& outmf, IntVect const& ngout) { BL_PROFILE("FFT::R2C::backward(out)"); @@ -399,6 +437,11 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout) if ( m_cmd_z2y) { ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); } +#if (AMREX_SPACEDIM == 3) + else if ( m_cmd_z2x) { + ParallelCopy(m_cx, m_cz, *m_cmd_z2x, 0, 0, 1, m_dtos_z2x); + } +#endif m_fft_bwd_y.template compute_c2c(); if ( m_cmd_y2x) { @@ -409,9 +452,9 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout) outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout); } -template +template std::pair, Plan> -R2C::make_c2c_plans (cMF& inout) +R2C::make_c2c_plans (cMF& inout) { Plan fwd; Plan bwd; @@ -437,9 +480,9 @@ R2C::make_c2c_plans (cMF& inout) return {fwd, bwd}; } -template +template template -void R2C::post_forward_doit (F const& post_forward) +void R2C::post_forward_doit (F const& post_forward) { if (m_info.batch_mode) { amrex::Abort("xxxxx todo: post_forward"); @@ -478,11 +521,25 @@ void R2C::post_forward_doit (F const& post_forward) } } -template +template +T R2C::scalingFactor () const +{ +#if (AMREX_SPACEDIM == 3) + if (m_info.batch_mode) { + return T(1)/T(Long(m_real_domain.length(0)) * + Long(m_real_domain.length(1))); + } else +#endif + { + return T(1)/T(m_real_domain.numPts()); + } +} + +template template > -std::pair::cMF *, IntVect> -R2C::getSpectralData () +std::pair::cMF *, IntVect> +R2C::getSpectralData () { if (!m_cz.empty()) { return std::make_pair(&m_cz, IntVect{AMREX_D_DECL(2,0,1)}); @@ -493,10 +550,10 @@ R2C::getSpectralData () } } -template +template template > -void R2C::forward (MF const& inmf, cMF& outmf) +void R2C::forward (MF const& inmf, cMF& outmf) { BL_PROFILE("FFT::R2C::forward(inout)"); @@ -515,10 +572,10 @@ void R2C::forward (MF const& inmf, cMF& outmf) } } -template +template template > -void R2C::backward (cMF const& inmf, MF& outmf) +void R2C::backward (cMF const& inmf, MF& outmf) { BL_PROFILE("FFT::R2C::backward(inout)"); @@ -537,9 +594,9 @@ void R2C::backward (cMF const& inmf, MF& outmf) backward_doit(outmf); } -template +template std::pair -R2C::getSpectralDataLayout () const +R2C::getSpectralDataLayout () const { #if (AMREX_SPACEDIM == 3) if (!m_cz.empty()) { diff --git a/Src/FFT/AMReX_FFT_R2X.H b/Src/FFT/AMReX_FFT_R2X.H index 7ff8805b03..5d916ada3c 100644 --- a/Src/FFT/AMReX_FFT_R2X.H +++ b/Src/FFT/AMReX_FFT_R2X.H @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -25,7 +26,8 @@ public: using cMF = FabArray > >; R2X (Box const& domain, - Array,AMREX_SPACEDIM> const& bc); + Array,AMREX_SPACEDIM> const& bc, + Info const& info = Info{}); ~R2X (); @@ -85,13 +87,17 @@ private: Box m_dom_cx; Box m_dom_cy; Box m_dom_cz; + + Info m_info; }; template R2X::R2X (Box const& domain, - Array,AMREX_SPACEDIM> const& bc) + Array,AMREX_SPACEDIM> const& bc, + Info const& info) : m_dom_0(domain), - m_bc(bc) + m_bc(bc), + m_info(info) { BL_PROFILE("FFT::R2X"); @@ -111,7 +117,7 @@ R2X::R2X (Box const& domain, } int myproc = ParallelContext::MyProcSub(); - int nprocs = ParallelContext::NProcsSub(); + int nprocs = std::min(ParallelContext::NProcsSub(), m_info.nprocs); // // make data containers diff --git a/Tests/FFT/R2C/main.cpp b/Tests/FFT/R2C/main.cpp index caa457a7f5..594a9ec760 100644 --- a/Tests/FFT/R2C/main.cpp +++ b/Tests/FFT/R2C/main.cpp @@ -74,7 +74,7 @@ int main (int argc, char* argv[]) // forward { - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c(geom.Domain()); auto const& [cba, cdm] = r2c.getSpectralDataLayout(); cmf.define(cba, cdm, 1, 0); r2c.forward(mf,cmf); @@ -82,7 +82,7 @@ int main (int argc, char* argv[]) // backward { - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c(geom.Domain()); r2c.backward(cmf,mf2); } @@ -105,7 +105,7 @@ int main (int argc, char* argv[]) mf2.setVal(std::numeric_limits::max()); { // forward and backward - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c(geom.Domain()); r2c.forwardThenBackward(mf, mf2, [=] AMREX_GPU_DEVICE (int, int, int, auto& sp) {