Skip to content

Commit

Permalink
HIP compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanstocks00 committed Jul 18, 2024
1 parent ef88e16 commit 616615f
Showing 1 changed file with 315 additions and 0 deletions.
315 changes: 315 additions & 0 deletions src/hip/builtin.hip
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,209 @@ __global__ GGA_EXC_VXC_INC_GENERATOR( device_eval_exc_vxc_inc_helper_polar_kerne

}

// Meta-GGA energy-only kernel, unpolarized variant.
//
// One thread handles one index idx (threads with idx >= N exit immediately).
// It reads rho[idx], sigma[idx], tau[idx] -- and lapl[idx] only when
// kernel_traits<KernelType>::needs_laplacian is set -- and forwards them,
// together with eps[idx] as the final argument, to traits::eval_exc_unpolar.
// The parameter list is declared by the MGGA_EXC_GENERATOR macro (not shown
// in this excerpt).
template <typename KernelType>
__global__ MGGA_EXC_GENERATOR( device_eval_exc_helper_unpolar_kernel ) {

  using functional = kernel_traits<KernelType>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the end of the data have nothing to do.
  if( !(idx < N) ) return;

  // lapl is never dereferenced for functionals that do not need it;
  // a zero is passed in its place.
  const double laplacian = functional::needs_laplacian ? lapl[idx] : 0.0;

  functional::eval_exc_unpolar( rho[idx], sigma[idx], laplacian, tau[idx],
                                eps[idx] );

}


// Meta-GGA energy-only kernel, polarized variant.
//
// One thread handles one index idx (threads with idx >= N exit immediately).
// Per index the inputs are laid out with a stride of 2 for rho, lapl and tau
// and a stride of 3 for sigma; eps has one entry per index. lapl is only
// dereferenced when kernel_traits<KernelType>::needs_laplacian is set,
// otherwise zeros are passed for both Laplacian arguments.
// The parameter list is declared by the MGGA_EXC_GENERATOR macro (not shown
// in this excerpt).
template <typename KernelType>
__global__ MGGA_EXC_GENERATOR( device_eval_exc_helper_polar_kernel ) {

  using functional = kernel_traits<KernelType>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the end of the data have nothing to do.
  if( !(idx < N) ) return;

  // Per-index views into the interleaved input arrays.
  auto* rho_pt   = rho   + 2*idx;
  auto* sigma_pt = sigma + 3*idx;
  auto* tau_pt   = tau   + 2*idx;
  auto* lapl_pt  = functional::needs_laplacian ? (lapl + 2*idx) : nullptr;

  // lapl_pt is null when the Laplacian is not needed, so guard the reads.
  const double lapl_a = functional::needs_laplacian ? lapl_pt[0] : 0.0;
  const double lapl_b = functional::needs_laplacian ? lapl_pt[1] : 0.0;

  functional::eval_exc_polar( rho_pt[0], rho_pt[1],
                              sigma_pt[0], sigma_pt[1], sigma_pt[2],
                              lapl_a, lapl_b,
                              tau_pt[0], tau_pt[1],
                              eps[idx] );

}

// Meta-GGA energy + derivative kernel, unpolarized variant.
//
// One thread handles one index idx (threads with idx >= N exit immediately).
// It forwards rho[idx], sigma[idx], tau[idx] and the Laplacian input to
// traits::eval_exc_vxc_unpolar, passing eps[idx], vrho[idx], vsigma[idx],
// the Laplacian-derivative slot and vtau[idx] as the trailing arguments.
// When kernel_traits<KernelType>::needs_laplacian is false, neither lapl nor
// vlapl is dereferenced: the input is replaced by 0.0 and the derivative slot
// by a thread-local scratch variable.
// The parameter list is declared by the MGGA_EXC_VXC_GENERATOR macro (not
// shown in this excerpt).
template <typename KernelType>
__global__ MGGA_EXC_VXC_GENERATOR( device_eval_exc_vxc_helper_unpolar_kernel ) {

  using functional = kernel_traits<KernelType>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the end of the data have nothing to do.
  if( !(idx < N) ) return;

  const double laplacian = functional::needs_laplacian ? lapl[idx] : 0.0;

  // Bind the Laplacian-derivative slot either to the real array element or
  // to a local throw-away variable.
  double vlapl_scratch;
  auto& vlapl_slot = functional::needs_laplacian ? vlapl[idx] : vlapl_scratch;

  functional::eval_exc_vxc_unpolar( rho[idx], sigma[idx], laplacian, tau[idx],
                                    eps[idx], vrho[idx], vsigma[idx],
                                    vlapl_slot, vtau[idx] );

}

// Meta-GGA energy + derivative kernel, polarized variant.
//
// One thread handles one index idx (threads with idx >= N exit immediately).
// Per index, rho/lapl/tau and vrho/vlapl/vtau use a stride of 2, sigma and
// vsigma use a stride of 3, and eps has a single entry. When
// kernel_traits<KernelType>::needs_laplacian is false, lapl is not read
// (zeros are passed instead) and the Laplacian-derivative arguments are bound
// to a thread-local two-element scratch array rather than vlapl.
// The parameter list is declared by the MGGA_EXC_VXC_GENERATOR macro (not
// shown in this excerpt).
template <typename KernelType>
__global__ MGGA_EXC_VXC_GENERATOR( device_eval_exc_vxc_helper_polar_kernel ) {

  using functional = kernel_traits<KernelType>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the end of the data have nothing to do.
  if( !(idx < N) ) return;

  // Throw-away target for the Laplacian derivatives when they are not needed.
  double vlapl_scratch[2];

  // Per-index views into the interleaved input arrays.
  auto* rho_pt   = rho   + 2*idx;
  auto* sigma_pt = sigma + 3*idx;
  auto* tau_pt   = tau   + 2*idx;
  auto* lapl_pt  = functional::needs_laplacian ? (lapl + 2*idx) : lapl;

  // Per-index views into the interleaved output arrays.
  auto* vrho_pt   = vrho   + 2*idx;
  auto* vsigma_pt = vsigma + 3*idx;
  auto* vtau_pt   = vtau   + 2*idx;
  auto* vlapl_pt  = functional::needs_laplacian ? vlapl + 2*idx : vlapl_scratch;

  // lapl_pt is only dereferenced when the functional needs the Laplacian.
  const double lapl_a = functional::needs_laplacian ? lapl_pt[0] : 0.0;
  const double lapl_b = functional::needs_laplacian ? lapl_pt[1] : 0.0;

  functional::eval_exc_vxc_polar( rho_pt[0], rho_pt[1],
                                  sigma_pt[0], sigma_pt[1], sigma_pt[2],
                                  lapl_a, lapl_b,
                                  tau_pt[0], tau_pt[1],
                                  eps[idx],
                                  vrho_pt[0], vrho_pt[1],
                                  vsigma_pt[0], vsigma_pt[1], vsigma_pt[2],
                                  vlapl_pt[0], vlapl_pt[1],
                                  vtau_pt[0], vtau_pt[1] );

}


// Meta-GGA energy-only kernel, unpolarized variant, accumulating form.
//
// One thread handles one index idx (threads with idx >= N exit immediately).
// The energy value is evaluated into a local variable by
// traits::eval_exc_unpolar and then added to eps[idx] scaled by scal_fact,
// i.e. eps[idx] += scal_fact * e. lapl is only read when
// kernel_traits<KernelType>::needs_laplacian is set.
// The parameter list is declared by the MGGA_EXC_INC_GENERATOR macro (not
// shown in this excerpt).
template <typename KernelType>
__global__ MGGA_EXC_INC_GENERATOR( device_eval_exc_inc_helper_unpolar_kernel ) {

  using functional = kernel_traits<KernelType>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the end of the data have nothing to do.
  if( !(idx < N) ) return;

  const double laplacian = functional::needs_laplacian ? lapl[idx] : 0.0;

  // Evaluate into a temporary so the existing eps entry is incremented
  // rather than overwritten.
  double exc_val;
  functional::eval_exc_unpolar( rho[idx], sigma[idx], laplacian, tau[idx],
                                exc_val );
  eps[idx] += scal_fact * exc_val;

}

// Meta-GGA energy-only kernel, polarized variant, accumulating form.
//
// One thread handles one index idx (threads with idx >= N exit immediately).
// Inputs use a per-index stride of 2 for rho, lapl and tau and a stride of 3
// for sigma. The energy value is evaluated into a local variable and then
// added to eps[idx] scaled by scal_fact. lapl is only dereferenced when
// kernel_traits<KernelType>::needs_laplacian is set.
// The parameter list is declared by the MGGA_EXC_INC_GENERATOR macro (not
// shown in this excerpt).
template <typename KernelType>
__global__ MGGA_EXC_INC_GENERATOR( device_eval_exc_inc_helper_polar_kernel ) {

  using functional = kernel_traits<KernelType>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the end of the data have nothing to do.
  if( !(idx < N) ) return;

  // Per-index views into the interleaved input arrays.
  auto* rho_pt   = rho   + 2*idx;
  auto* sigma_pt = sigma + 3*idx;
  auto* tau_pt   = tau   + 2*idx;
  auto* lapl_pt  = functional::needs_laplacian ? (lapl + 2*idx) : lapl;

  // lapl_pt is only dereferenced when the functional needs the Laplacian.
  const double lapl_a = functional::needs_laplacian ? lapl_pt[0] : 0.0;
  const double lapl_b = functional::needs_laplacian ? lapl_pt[1] : 0.0;

  // Evaluate into a temporary so the existing eps entry is incremented
  // rather than overwritten.
  double exc_val;
  functional::eval_exc_polar( rho_pt[0], rho_pt[1],
                              sigma_pt[0], sigma_pt[1], sigma_pt[2],
                              lapl_a, lapl_b,
                              tau_pt[0], tau_pt[1],
                              exc_val );
  eps[idx] += scal_fact * exc_val;

}

// Meta-GGA energy + derivative kernel, unpolarized variant, accumulating form.
//
// One thread handles one index idx (threads with idx >= N exit immediately).
// traits::eval_exc_vxc_unpolar is evaluated into local temporaries, which are
// then added -- scaled by scal_fact -- to eps[idx], vrho[idx], vsigma[idx]
// and vtau[idx]. vlapl[idx] is updated, and lapl[idx] read, only when
// kernel_traits<KernelType>::needs_laplacian is set.
// The parameter list is declared by the MGGA_EXC_VXC_INC_GENERATOR macro (not
// shown in this excerpt).
template <typename KernelType>
__global__ MGGA_EXC_VXC_INC_GENERATOR( device_eval_exc_vxc_inc_helper_unpolar_kernel ) {

  using functional = kernel_traits<KernelType>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the end of the data have nothing to do.
  if( !(idx < N) ) return;

  const double laplacian = functional::needs_laplacian ? lapl[idx] : 0.0;

  // Temporaries receiving the un-scaled results.
  double exc_val, vrho_val, vsigma_val, vlapl_val, vtau_val;
  functional::eval_exc_vxc_unpolar( rho[idx], sigma[idx], laplacian, tau[idx],
                                    exc_val, vrho_val, vsigma_val, vlapl_val,
                                    vtau_val );

  eps[idx]    += scal_fact * exc_val;
  vrho[idx]   += scal_fact * vrho_val;
  vsigma[idx] += scal_fact * vsigma_val;
  vtau[idx]   += scal_fact * vtau_val;

  // vlapl is left untouched for functionals without Laplacian dependence.
  if( functional::needs_laplacian ) {
    vlapl[idx] += scal_fact * vlapl_val;
  }

}

// Meta-GGA energy + derivative kernel, polarized variant, accumulating form.
//
// One thread handles one index idx (threads with idx >= N exit immediately).
// Per index, rho/lapl/tau and vrho/vlapl/vtau use a stride of 2, sigma and
// vsigma a stride of 3, and eps a single entry. traits::eval_exc_vxc_polar is
// evaluated into local temporaries, which are then added -- scaled by
// scal_fact -- to the output arrays. lapl is read and vlapl updated only when
// kernel_traits<KernelType>::needs_laplacian is set.
// The parameter list is declared by the MGGA_EXC_VXC_INC_GENERATOR macro (not
// shown in this excerpt).
template <typename KernelType>
__global__ MGGA_EXC_VXC_INC_GENERATOR( device_eval_exc_vxc_inc_helper_polar_kernel ) {

  using functional = kernel_traits<KernelType>;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the end of the data have nothing to do.
  if( !(idx < N) ) return;

  // Fallback target so vlapl_pt is always a valid pointer.
  double vlapl_scratch[2];

  // Per-index views into the interleaved input arrays.
  auto* rho_pt   = rho   + 2*idx;
  auto* sigma_pt = sigma + 3*idx;
  auto* tau_pt   = tau   + 2*idx;
  auto* lapl_pt  = functional::needs_laplacian ? (lapl + 2*idx) : lapl;

  // Per-index views into the interleaved output arrays.
  auto* vrho_pt   = vrho   + 2*idx;
  auto* vsigma_pt = vsigma + 3*idx;
  auto* vtau_pt   = vtau   + 2*idx;
  auto* vlapl_pt  = functional::needs_laplacian ? vlapl + 2*idx : vlapl_scratch;

  // lapl_pt is only dereferenced when the functional needs the Laplacian.
  const double lapl_a = functional::needs_laplacian ? lapl_pt[0] : 0.0;
  const double lapl_b = functional::needs_laplacian ? lapl_pt[1] : 0.0;

  // Temporaries receiving the un-scaled results.
  double exc_val;
  double vrho_a, vrho_b;
  double vsigma_aa, vsigma_ab, vsigma_bb;
  double vlapl_a, vlapl_b;
  double vtau_a, vtau_b;

  functional::eval_exc_vxc_polar( rho_pt[0], rho_pt[1],
                                  sigma_pt[0], sigma_pt[1], sigma_pt[2],
                                  lapl_a, lapl_b,
                                  tau_pt[0], tau_pt[1],
                                  exc_val,
                                  vrho_a, vrho_b,
                                  vsigma_aa, vsigma_ab, vsigma_bb,
                                  vlapl_a, vlapl_b,
                                  vtau_a, vtau_b );

  eps[idx]     += scal_fact * exc_val;
  vrho_pt[0]   += scal_fact * vrho_a;
  vrho_pt[1]   += scal_fact * vrho_b;
  vsigma_pt[0] += scal_fact * vsigma_aa;
  vsigma_pt[1] += scal_fact * vsigma_ab;
  vsigma_pt[2] += scal_fact * vsigma_bb;
  vtau_pt[0]   += scal_fact * vtau_a;
  vtau_pt[1]   += scal_fact * vtau_b;

  // vlapl is left untouched for functionals without Laplacian dependence.
  if( functional::needs_laplacian ) {
    vlapl_pt[0] += scal_fact * vlapl_a;
    vlapl_pt[1] += scal_fact * vlapl_b;
  }

}

template <typename KernelType>
LDA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_unpolar ) {
Expand Down Expand Up @@ -582,6 +785,99 @@ GGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar ) {

}

// Host-side launcher for device_eval_exc_helper_unpolar_kernel<KernelType>.
//
// Enqueues the kernel on `stream` using 32 threads per block and
// util::div_ceil( N, 32 ) blocks, forwarding N, rho, sigma, lapl, tau and eps
// unchanged. The signature is produced by the MGGA_EXC_GENERATOR_DEVICE macro
// (not shown in this excerpt).
//
// Nothing is launched when N is not positive: a grid with zero blocks is
// rejected by the runtime as an invalid launch configuration, and there is
// no work to do in that case anyway.
template <typename KernelType>
MGGA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_unpolar ) {

  // Empty input: skip the launch instead of requesting a zero-sized grid.
  if( N > 0 ) {

    dim3 threads(32);
    dim3 blocks( util::div_ceil( N, threads.x) );
    device_eval_exc_helper_unpolar_kernel<KernelType><<<blocks,threads,0,stream>>>(
      N, rho, sigma, lapl, tau, eps
    );

  }

}

// Host-side launcher for device_eval_exc_helper_polar_kernel<KernelType>.
//
// Enqueues the kernel on `stream` using 32 threads per block and
// util::div_ceil( N, 32 ) blocks, forwarding N, rho, sigma, lapl, tau and eps
// unchanged. The signature is produced by the MGGA_EXC_GENERATOR_DEVICE macro
// (not shown in this excerpt).
//
// Nothing is launched when N is not positive: a grid with zero blocks is
// rejected by the runtime as an invalid launch configuration, and there is
// no work to do in that case anyway.
template <typename KernelType>
MGGA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_polar ) {

  // Empty input: skip the launch instead of requesting a zero-sized grid.
  if( N > 0 ) {

    dim3 threads(32);
    dim3 blocks( util::div_ceil( N, threads.x) );
    device_eval_exc_helper_polar_kernel<KernelType><<<blocks,threads,0,stream>>>(
      N, rho, sigma, lapl, tau, eps
    );

  }

}

// Host-side launcher for device_eval_exc_vxc_helper_unpolar_kernel<KernelType>.
//
// Enqueues the kernel on `stream` using 32 threads per block and
// util::div_ceil( N, 32 ) blocks, forwarding N, rho, sigma, lapl, tau, eps,
// vrho, vsigma, vlapl and vtau unchanged. The signature is produced by the
// MGGA_EXC_VXC_GENERATOR_DEVICE macro (not shown in this excerpt).
//
// Nothing is launched when N is not positive: a grid with zero blocks is
// rejected by the runtime as an invalid launch configuration, and there is
// no work to do in that case anyway.
template <typename KernelType>
MGGA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_unpolar ) {

  // Empty input: skip the launch instead of requesting a zero-sized grid.
  if( N > 0 ) {

    dim3 threads(32);
    dim3 blocks( util::div_ceil( N, threads.x) );

    device_eval_exc_vxc_helper_unpolar_kernel<KernelType><<<blocks,threads,0,stream>>>(
      N, rho, sigma, lapl, tau, eps, vrho, vsigma, vlapl, vtau
    );

  }

}

// Host-side launcher for device_eval_exc_vxc_helper_polar_kernel<KernelType>.
//
// Enqueues the kernel on `stream` using 32 threads per block and
// util::div_ceil( N, 32 ) blocks, forwarding N, rho, sigma, lapl, tau, eps,
// vrho, vsigma, vlapl and vtau unchanged. The signature is produced by the
// MGGA_EXC_VXC_GENERATOR_DEVICE macro (not shown in this excerpt).
//
// Nothing is launched when N is not positive: a grid with zero blocks is
// rejected by the runtime as an invalid launch configuration, and there is
// no work to do in that case anyway.
template <typename KernelType>
MGGA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_polar ) {

  // Empty input: skip the launch instead of requesting a zero-sized grid.
  if( N > 0 ) {

    dim3 threads(32);
    dim3 blocks( util::div_ceil( N, threads.x) );

    device_eval_exc_vxc_helper_polar_kernel<KernelType><<<blocks,threads,0,stream>>>(
      N, rho, sigma, lapl, tau, eps, vrho, vsigma, vlapl, vtau
    );

  }

}


// Host-side launcher for device_eval_exc_inc_helper_unpolar_kernel<KernelType>.
//
// Enqueues the kernel on `stream` using 32 threads per block and
// util::div_ceil( N, 32 ) blocks, forwarding scal_fact, N, rho, sigma, lapl,
// tau and eps unchanged. The signature is produced by the
// MGGA_EXC_INC_GENERATOR_DEVICE macro (not shown in this excerpt).
//
// Nothing is launched when N is not positive: a grid with zero blocks is
// rejected by the runtime as an invalid launch configuration, and there is
// no work to do in that case anyway.
template <typename KernelType>
MGGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_unpolar ) {

  // Empty input: skip the launch instead of requesting a zero-sized grid.
  if( N > 0 ) {

    dim3 threads(32);
    dim3 blocks( util::div_ceil( N, threads.x) );
    device_eval_exc_inc_helper_unpolar_kernel<KernelType><<<blocks,threads,0,stream>>>(
      scal_fact, N, rho, sigma, lapl, tau, eps
    );

  }

}

// Host-side launcher for device_eval_exc_inc_helper_polar_kernel<KernelType>.
//
// Enqueues the kernel on `stream` using 32 threads per block and
// util::div_ceil( N, 32 ) blocks, forwarding scal_fact, N, rho, sigma, lapl,
// tau and eps unchanged. The signature is produced by the
// MGGA_EXC_INC_GENERATOR_DEVICE macro (not shown in this excerpt).
//
// Nothing is launched when N is not positive: a grid with zero blocks is
// rejected by the runtime as an invalid launch configuration, and there is
// no work to do in that case anyway.
template <typename KernelType>
MGGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_polar ) {

  // Empty input: skip the launch instead of requesting a zero-sized grid.
  if( N > 0 ) {

    dim3 threads(32);
    dim3 blocks( util::div_ceil( N, threads.x) );
    device_eval_exc_inc_helper_polar_kernel<KernelType><<<blocks,threads,0,stream>>>(
      scal_fact, N, rho, sigma, lapl, tau, eps
    );

  }

}

// Host-side launcher for
// device_eval_exc_vxc_inc_helper_unpolar_kernel<KernelType>.
//
// Enqueues the kernel on `stream` using 32 threads per block and
// util::div_ceil( N, 32 ) blocks, forwarding scal_fact, N, rho, sigma, lapl,
// tau, eps, vrho, vsigma, vlapl and vtau unchanged. The signature is produced
// by the MGGA_EXC_VXC_INC_GENERATOR_DEVICE macro (not shown in this excerpt).
//
// Nothing is launched when N is not positive: a grid with zero blocks is
// rejected by the runtime as an invalid launch configuration, and there is
// no work to do in that case anyway.
template <typename KernelType>
MGGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_unpolar ) {

  // Empty input: skip the launch instead of requesting a zero-sized grid.
  if( N > 0 ) {

    dim3 threads(32);
    dim3 blocks( util::div_ceil( N, threads.x) );

    device_eval_exc_vxc_inc_helper_unpolar_kernel<KernelType><<<blocks,threads,0,stream>>>(
      scal_fact, N, rho, sigma, lapl, tau, eps, vrho, vsigma, vlapl, vtau
    );

  }

}

// Host-side launcher for
// device_eval_exc_vxc_inc_helper_polar_kernel<KernelType>.
//
// Enqueues the kernel on `stream` using 32 threads per block and
// util::div_ceil( N, 32 ) blocks, forwarding scal_fact, N, rho, sigma, lapl,
// tau, eps, vrho, vsigma, vlapl and vtau unchanged. The signature is produced
// by the MGGA_EXC_VXC_INC_GENERATOR_DEVICE macro (not shown in this excerpt).
//
// Nothing is launched when N is not positive: a grid with zero blocks is
// rejected by the runtime as an invalid launch configuration, and there is
// no work to do in that case anyway.
template <typename KernelType>
MGGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar ) {

  // Empty input: skip the launch instead of requesting a zero-sized grid.
  if( N > 0 ) {

    dim3 threads(32);
    dim3 blocks( util::div_ceil( N, threads.x) );

    device_eval_exc_vxc_inc_helper_polar_kernel<KernelType><<<blocks,threads,0,stream>>>(
      scal_fact, N, rho, sigma, lapl, tau, eps, vrho, vsigma, vlapl, vtau
    );

  }

}

#define LDA_GENERATE_DEVICE_HELPERS(KERN) \
template LDA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_unpolar<KERN> ); \
template LDA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_unpolar<KERN> ); \
Expand All @@ -602,6 +898,16 @@ GGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar ) {
template GGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_polar<KERN> ); \
template GGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar<KERN> );

// Emits explicit template instantiations of the eight meta-GGA device helper
// launchers for kernel type KERN: the EXC and EXC_VXC helpers and their INC
// counterparts, each in an unpolar and a polar flavour.
#define MGGA_GENERATE_DEVICE_HELPERS(KERN) \
  template MGGA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_unpolar<KERN> ); \
  template MGGA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_polar<KERN> ); \
  template MGGA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_unpolar<KERN> ); \
  template MGGA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_polar<KERN> ); \
  template MGGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_unpolar<KERN> ); \
  template MGGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_polar<KERN> ); \
  template MGGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_unpolar<KERN> ); \
  template MGGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar<KERN> );

// Explicit instantiations of the LDA device helpers, one macro invocation per
// builtin kernel type.
LDA_GENERATE_DEVICE_HELPERS( BuiltinSlaterExchange );
LDA_GENERATE_DEVICE_HELPERS( BuiltinVWN3 );
LDA_GENERATE_DEVICE_HELPERS( BuiltinVWN_RPA );
Expand All @@ -624,6 +930,15 @@ MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCAN_X );
// Explicit instantiations of the meta-GGA device helpers: each invocation
// below instantiates the launchers emitted by MGGA_GENERATE_DEVICE_HELPERS
// for the named builtin kernel type.
MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCAN_X );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCAN_C );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinR2SCAN_X );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinR2SCAN_C );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinFT98_X );

// PC07 kernel types.
MGGA_GENERATE_DEVICE_HELPERS( BuiltinPC07_K );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinPC07OPT_K );

// SCANL / R2SCANL kernel types.
MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCANL_C );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCANL_X );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinR2SCANL_C );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinR2SCANL_X );

// The EPC17 kernel types are instantiated through the LDA helper macro.
LDA_GENERATE_DEVICE_HELPERS( BuiltinEPC17_1 )
LDA_GENERATE_DEVICE_HELPERS( BuiltinEPC17_2 )
Expand Down

0 comments on commit 616615f

Please sign in to comment.