Skip to content

Commit

Permalink
separate function for cuda compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
RevathiJambunathan committed Mar 18, 2024
1 parent c6675cd commit c042680
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 62 deletions.
5 changes: 5 additions & 0 deletions Source/WarpX.H
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,11 @@ public:
// for cuda
void BuildBufferMasksInBox ( amrex::Box tbx, amrex::IArrayBox &buffer_mask,
const amrex::IArrayBox &guard_mask, int ng );

void InitInterpolationWeightsInBuffer( const amrex::Box tbx, amrex::FArrayBox &weights_gbuffer,
const amrex::IArrayBox & buffer_mask,
const amrex::IArrayBox &guard_mask, int ngbuffer,
bool do_interpolate, amrex::Real tanh_midpoint);
#ifdef AMREX_USE_EB
amrex::EBFArrayBoxFactory const& fieldEBFactory (int lev) const noexcept {
return static_cast<amrex::EBFArrayBoxFactory const&>(*m_field_factory[lev]);
Expand Down
136 changes: 74 additions & 62 deletions Source/WarpX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3107,68 +3107,8 @@ WarpX::BuildBufferMasks ()
for (MFIter mfi(*weight_gbuffer, true); mfi.isValid(); ++mfi)
{
const Box& tbx = mfi.tilebox(IntVect::TheNodeVector(),weight_gbuffer->nGrowVect());
auto const& gmsk = tmp[mfi].const_array();
auto const& bmsk = (*bmasks)[mfi].array();
auto const& wtmsk = (*weight_gbuffer)[mfi].array();
amrex::ParallelFor(tbx, [=] AMREX_GPU_DEVICE(int i, int j, int k) {
wtmsk(i,j,k) = 0._rt;
if (bmsk(i,j,k) == 0 && do_interpolate) {
if(gmsk(i,j,k)==0) {
wtmsk(i,j,k) = 0.;
return;
}
//for (int ii = i-1; ii >= i-ngbuffer; --ii) {
// if (gmsk(ii,j,k)==0) {
// amrex::Real arg = (static_cast<amrex::Real>(i-ii)-ngbuffer*tanh_midpoint)
// / ((1.-tanh_midpoint)*(ngbuffer/3.));
// wtmsk(i,j,k) = std::tanh(arg)*0.5 + 0.5;
// amrex::Print() << " i edge wt is " << wtmsk(i,j,k) << "\n";
// return;
// }
//}
//for (int ii = i+1; ii <= i+ngbuffer; ++ii) {
// if (gmsk(ii,j,k)==0) {
// amrex::Real arg = (static_cast<amrex::Real>(ii-i)-ngbuffer*tanh_midpoint)
// / ((1.-tanh_midpoint)*(ngbuffer/3.));
// wtmsk(i,j,k) = std::tanh(arg)*0.5+0.5;
// amrex::Print() << " wt is " << wtmsk(i,j,k) << "\n";
// return;
// }
//}
for (int jj = j-1; jj >= j-ngbuffer; --jj) {
if (gmsk(i,jj,k)==0) {
amrex::Real arg = (static_cast<amrex::Real>(j-jj)-ngbuffer*tanh_midpoint)
/ ((1.-tanh_midpoint)*(ngbuffer/3.));
wtmsk(i,j,k) = std::tanh(arg)*0.5 + 0.5;
return;
}
}
for (int jj = j+1; jj <= j+ngbuffer; ++jj) {
if (gmsk(i,jj,k)==0) {
amrex::Real arg = (static_cast<amrex::Real>(jj - j)-ngbuffer*tanh_midpoint)
/ ((1.-tanh_midpoint)*(ngbuffer/3.));
wtmsk(i,j,k) = std::tanh(arg)*0.5+0.5;
return;
}
}
//for (int kk = k-1; kk >= k-ngbuffer; --kk) {
// if (gmsk(i,j,kk)==0) {
// amrex::Real arg = (static_cast<amrex::Real>(k-kk)-ngbuffer*tanh_midpoint)
// / ((1.-tanh_midpoint)*(ngbuffer/3.));
// wtmsk(i,j,k) = std::tanh(arg)*0.5+0.5;
// return;
// }
//}
//for (int kk = k+1; kk <= k+ngbuffer; ++kk) {
// if (gmsk(i,j,kk)==0) {
// amrex::Real arg = (static_cast<amrex::Real>(kk-k)-ngbuffer*tanh_midpoint)
// / ((1.-tanh_midpoint)*(ngbuffer/3.));
// wtmsk(i,j,k) = std::tanh(arg)*0.5+0.5;
// return;
// }
//}
}
});
InitInterpolationWeightsInBuffer( tbx, (*weight_gbuffer)[mfi], (*bmasks)[mfi], tmp[mfi],
ngbuffer, do_interpolate, tanh_midpoint);
}
}
}
Expand Down Expand Up @@ -3206,6 +3146,78 @@ WarpX::BuildBufferMasksInBox ( const amrex::Box tbx, amrex::IArrayBox &buffer_ma
});
}


void
WarpX::InitInterpolationWeightsInBuffer( const amrex::Box tbx, amrex::FArrayBox &weights_gbuffer,
const amrex::IArrayBox & buffer_mask,
const amrex::IArrayBox &guard_mask, int ngbuffer,
bool do_interpolate, amrex::Real tanh_midpoint)
{
auto const& gmsk = guard_mask.const_array();
auto const& bmsk = buffer_mask.array();
auto const& wtmsk = weights_gbuffer.array();
amrex::ParallelFor(tbx, [=] AMREX_GPU_DEVICE(int i, int j, int k) {
wtmsk(i,j,k) = 0._rt;
if (bmsk(i,j,k) == 0 && do_interpolate) {
if(gmsk(i,j,k)==0) {
wtmsk(i,j,k) = 0.;
return;
}
//for (int ii = i-1; ii >= i-ngbuffer; --ii) {
// if (gmsk(ii,j,k)==0) {
// amrex::Real arg = (static_cast<amrex::Real>(i-ii)-ngbuffer*tanh_midpoint)
// / ((1.-tanh_midpoint)*(ngbuffer/3.));
// wtmsk(i,j,k) = std::tanh(arg)*0.5 + 0.5;
// amrex::Print() << " i edge wt is " << wtmsk(i,j,k) << "\n";
// return;
// }
//}
//for (int ii = i+1; ii <= i+ngbuffer; ++ii) {
// if (gmsk(ii,j,k)==0) {
// amrex::Real arg = (static_cast<amrex::Real>(ii-i)-ngbuffer*tanh_midpoint)
// / ((1.-tanh_midpoint)*(ngbuffer/3.));
// wtmsk(i,j,k) = std::tanh(arg)*0.5+0.5;
// amrex::Print() << " wt is " << wtmsk(i,j,k) << "\n";
// return;
// }
//}
for (int jj = j-1; jj >= j-ngbuffer; --jj) {
if (gmsk(i,jj,k)==0) {
amrex::Real arg = (static_cast<amrex::Real>(j-jj)-ngbuffer*tanh_midpoint)
/ ((1.-tanh_midpoint)*(ngbuffer/3.));
wtmsk(i,j,k) = std::tanh(arg)*0.5 + 0.5;
return;
}
}
for (int jj = j+1; jj <= j+ngbuffer; ++jj) {
if (gmsk(i,jj,k)==0) {
amrex::Real arg = (static_cast<amrex::Real>(jj - j)-ngbuffer*tanh_midpoint)
/ ((1.-tanh_midpoint)*(ngbuffer/3.));
wtmsk(i,j,k) = std::tanh(arg)*0.5+0.5;
return;
}
}
//for (int kk = k-1; kk >= k-ngbuffer; --kk) {
// if (gmsk(i,j,kk)==0) {
// amrex::Real arg = (static_cast<amrex::Real>(k-kk)-ngbuffer*tanh_midpoint)
// / ((1.-tanh_midpoint)*(ngbuffer/3.));
// wtmsk(i,j,k) = std::tanh(arg)*0.5+0.5;
// return;
// }
//}
//for (int kk = k+1; kk <= k+ngbuffer; ++kk) {
// if (gmsk(i,j,kk)==0) {
// amrex::Real arg = (static_cast<amrex::Real>(kk-k)-ngbuffer*tanh_midpoint)
// / ((1.-tanh_midpoint)*(ngbuffer/3.));
// wtmsk(i,j,k) = std::tanh(arg)*0.5+0.5;
// return;
// }
//}
}
});
}


amrex::Vector<amrex::Real> WarpX::getFornbergStencilCoefficients(const int n_order, const short a_grid_type)
{
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(n_order % 2 == 0, "n_order must be even");
Expand Down

0 comments on commit c042680

Please sign in to comment.