Skip to content

Commit

Permalink
Refactor hpsi_func in hsolver (deepmodeling#5202)
Browse files Browse the repository at this point in the history
* Refactor hpsi_func of dav_subspace

* Modified the hpsi_func in pyabacus to maintain definition consistency

* Fix hpsi_func in pyabacus-dav_subspace

* [pre-commit.ci lite] apply automatic fixes

* Refactor hpsi_func of dav

* Change hpsi_func of hsolver_lrtd

* Modify hpsi_func definition in dav tests

* Modify the hpsi_func in pyabacus to maintain definition consistency

* [pre-commit.ci lite] apply automatic fixes

* Modify hsolver_pw_sup mock func signature

* Update docs for new hpsi_func

* Update docs

* Change indent to make it prettier

* Update docs

* Update spsi docs

* Update parameter name of spsi_func interface

* Rename leading dimension vars from camel to snake case

* Rename leading dimension in pyabacus to snake case

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
Cstandardlib and pre-commit-ci-lite[bot] authored Oct 5, 2024
1 parent 800987a commit 72b1d7c
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 137 deletions.
12 changes: 5 additions & 7 deletions python/pyabacus/src/py_diago_dav_subspace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,21 @@ class PyDiagoDavSubspace
auto hpsi_func = [mm_op] (
std::complex<double> *psi_in,
std::complex<double> *hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
const int band_index2
const int ld_psi,
const int nvec
) {
// Note: numpy's py::array_t is row-major, but
// our raw pointer-array is column-major
py::array_t<std::complex<double>, py::array::f_style> psi({nbasis_in, band_index2 - band_index1 + 1});
py::array_t<std::complex<double>, py::array::f_style> psi({ld_psi, nvec});
py::buffer_info psi_buf = psi.request();
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
std::copy(psi_in + band_index1 * nbasis_in, psi_in + (band_index2 + 1) * nbasis_in, psi_ptr);
std::copy(psi_in, psi_in + nvec * ld_psi, psi_ptr);

py::array_t<std::complex<double>, py::array::f_style> hpsi = mm_op(psi);

py::buffer_info hpsi_buf = hpsi.request();
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
std::copy(hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1) * nbasis_in, hpsi_out);
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
};

obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>>(
Expand Down
12 changes: 5 additions & 7 deletions python/pyabacus/src/py_diago_david.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,21 @@ class PyDiagoDavid
auto hpsi_func = [mm_op] (
std::complex<double> *psi_in,
std::complex<double> *hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
const int band_index2
const int ld_psi,
const int nvec
) {
// Note: numpy's py::array_t is row-major, but
// our raw pointer-array is column-major
py::array_t<std::complex<double>, py::array::f_style> psi({nbasis_in, band_index2 - band_index1 + 1});
py::array_t<std::complex<double>, py::array::f_style> psi({ld_psi, nvec});
py::buffer_info psi_buf = psi.request();
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
std::copy(psi_in + band_index1 * nbasis_in, psi_in + (band_index2 + 1) * nbasis_in, psi_ptr);
std::copy(psi_in, psi_in + nvec * ld_psi, psi_ptr);

py::array_t<std::complex<double>, py::array::f_style> hpsi = mm_op(psi);

py::buffer_info hpsi_buf = hpsi.request();
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
std::copy(hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1) * nbasis_in, hpsi_out);
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
};

auto spsi_func = [this] (
Expand Down
9 changes: 6 additions & 3 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,

// compute h*psi_in_iter
// NOTE: bands after the first n_band should yield zero
hpsi_func(this->psi_in_iter, this->hphi, this->nbase_x, this->dim, 0, this->nbase_x - 1);
// hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
hpsi_func(this->psi_in_iter, this->hphi, this->dim, this->nbase_x);

// at this stage, notconv = n_band and nbase = 0
// note that nbase of cal_elem is an inout parameter: nbase := nbase + notconv
Expand Down Expand Up @@ -421,7 +422,8 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
}

// update hpsi[:, nbase:nbase+notconv]
hpsi_func(psi_iter, &hphi[nbase * this->dim], this->nbase_x, this->dim, nbase, nbase + notconv - 1);
// hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
hpsi_func(psi_iter + nbase * dim, hphi + nbase * this->dim, this->dim, notconv);

ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
return;
Expand Down Expand Up @@ -886,7 +888,8 @@ void Diago_DavSubspace<T, Device>::diagH_subspace(T* psi_pointer, // [in] & [out

{
// do hPsi for all bands
hpsi_func(psi_pointer, hphi, n_band, dmax, 0, nstart - 1);
// hphi[:, 0:nstart] = H * psi_pointer[:, 0:nstart]
hpsi_func(psi_pointer, hphi, dmax, nstart);

gemm_op<T, Device>()(ctx,
'C',
Expand Down
3 changes: 2 additions & 1 deletion source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class Diago_DavSubspace : public DiagH<T, Device>

virtual ~Diago_DavSubspace() override;

using HPsiFunc = std::function<void(T*, T*, const int, const int, const int, const int)>;
// See diago_david.h for information on the HPsiFunc function type
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;

int diag(const HPsiFunc& hpsi_func,
T* psi_in,
Expand Down
39 changes: 20 additions & 19 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int dim,
const int nband,
const int ldPsi,
const int ld_psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
Expand Down Expand Up @@ -191,20 +191,20 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
if(this->use_paw)
{
#ifdef USE_PAW
GlobalC::paw_cell.paw_nl_psi(1, reinterpret_cast<const std::complex<double>*> (psi_in + m*ldPsi),
GlobalC::paw_cell.paw_nl_psi(1, reinterpret_cast<const std::complex<double>*> (psi_in + m*ld_psi),
reinterpret_cast<std::complex<double>*>(&this->spsi[m * dim]));
#endif
}
else
{
// phm_in->sPsi(psi_in + m*ldPsi, &this->spsi[m * dim], dim, dim, 1);
spsi_func(psi_in + m*ldPsi,&this->spsi[m*dim],dim,dim,1);
// phm_in->sPsi(psi_in + m*ld_psi, &this->spsi[m * dim], dim, dim, 1);
spsi_func(psi_in + m*ld_psi,&this->spsi[m*dim],dim,dim,1);
}
}
// begin SchmidtOrth
for (int m = 0; m < nband; m++)
{
syncmem_complex_op()(this->ctx, this->ctx, basis + dim*m, psi_in + m*ldPsi, dim);
syncmem_complex_op()(this->ctx, this->ctx, basis + dim*m, psi_in + m*ld_psi, dim);

this->SchmidtOrth(dim,
nband,
Expand All @@ -230,7 +230,9 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
// end of SchmidtOrth and calculate H|psi>
// hpsi_info dav_hpsi_in(&basis, psi::Range(true, 0, 0, nband - 1), this->hpsi);
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(basis, hpsi, nbase_x, dim, 0, nband - 1);
// hpsi[:, 0:nband] = H basis[:, 0:nband]
// slice index in this piece of code is in C manner. i.e. 0:id stands for [0,id)
hpsi_func(basis, hpsi, dim, nband);

this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hpsi, this->spsi, this->hcc, this->scc);

Expand Down Expand Up @@ -287,7 +289,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,

// update eigenvectors of Hamiltonian

setmem_complex_op()(this->ctx, psi_in, 0, nband * ldPsi);
setmem_complex_op()(this->ctx, psi_in, 0, nband * ld_psi);
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
gemm_op<T, Device>()(this->ctx,
'N',
Expand All @@ -302,7 +304,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
nbase_x,
this->zero,
psi_in, // C dim * nband
ldPsi
ld_psi
);

if (!this->notconv || (dav_iter == david_maxiter))
Expand All @@ -322,7 +324,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
nbase_x,
eigenvalue_in,
psi_in,
ldPsi,
ld_psi,
this->hpsi,
this->spsi,
this->hcc,
Expand Down Expand Up @@ -601,7 +603,8 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
// psi::Range(true, 0, nbase, nbase + notconv - 1),
// &hpsi[nbase * dim]); // &hp(nbase, 0)
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(basis, &hpsi[nbase * dim], nbase_x, dim, nbase, nbase + notconv - 1);
// hpsi[:, nbase:nbase+notcnv] = H basis[:, nbase:nbase+notcnv]
hpsi_func(basis + nbase * dim, hpsi + nbase * dim, dim, notconv);

delmem_complex_op()(this->ctx, lagrange);
delmem_complex_op()(this->ctx, vc_ev_vector);
Expand Down Expand Up @@ -785,7 +788,7 @@ void DiagoDavid<T, Device>::diag_zhegvx(const int& nbase,
* @param nbase_x The maximum dimension of the reduced basis set.
* @param eigenvalue_in Pointer to the array of eigenvalues.
* @param psi_in Pointer to the array of wavefunctions.
* @param ldPsi The leading dimension of the wavefunction array.
* @param ld_psi The leading dimension of the wavefunction array.
* @param hpsi Pointer to the output array for the updated basis set.
* @param spsi Pointer to the output array for the updated basis set (nband-th column).
* @param hcc Pointer to the output array for the updated reduced Hamiltonian.
Expand All @@ -800,7 +803,7 @@ void DiagoDavid<T, Device>::refresh(const int& dim,
const int nbase_x, // maximum dimension of the reduced basis set
const Real* eigenvalue_in,
const T *psi_in,
const int ldPsi,
const int ld_psi,
T* hpsi,
T* spsi,
T* hcc,
Expand Down Expand Up @@ -866,7 +869,7 @@ void DiagoDavid<T, Device>::refresh(const int& dim,

for (int m = 0; m < nband; m++)
{
syncmem_complex_op()(this->ctx, this->ctx, basis + dim*m,psi_in + m*ldPsi, dim);
syncmem_complex_op()(this->ctx, this->ctx, basis + dim*m,psi_in + m*ld_psi, dim);
/*for (int ig = 0; ig < npw; ig++)
basis(m, ig) = psi(m, ig);*/
}
Expand Down Expand Up @@ -1149,15 +1152,13 @@ void DiagoDavid<T, Device>::planSchmidtOrth(const int nband, std::vector<int>& p
/**
* @brief Performs iterative diagonalization using the David algorithm.
*
* @warning Please see docs of `HPsiFunc` for more information.
* @warning Please adhere strictly to the requirements of the function pointer
* @warning for the hpsi mat-vec interface; it may seem counterintuitive.
* @warning Please see docs of `HPsiFunc` for more information about the hpsi mat-vec interface.
*
* @tparam T The type of the elements in the matrix.
* @tparam Device The device type (CPU or GPU).
* @param hpsi_func The function object that computes the matrix-blockvector product H * psi.
* @param spsi_func The function object that computes the matrix-blockvector product overlap S * psi.
* @param ldPsi The leading dimension of the psi_in array.
* @param ld_psi The leading dimension of the psi_in array.
* @param psi_in The input wavefunction.
* @param eigenvalue_in The array to store the eigenvalues.
* @param david_diag_thr The convergence threshold for the diagonalization.
Expand All @@ -1172,7 +1173,7 @@ void DiagoDavid<T, Device>::planSchmidtOrth(const int nband, std::vector<int>& p
template <typename T, typename Device>
int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int ldPsi,
const int ld_psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
Expand All @@ -1187,7 +1188,7 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
int sum_dav_iter = 0;
do
{
sum_dav_iter += this->diag_once(hpsi_func, spsi_func, dim, nband, ldPsi, psi_in, eigenvalue_in, david_diag_thr, david_maxiter);
sum_dav_iter += this->diag_once(hpsi_func, spsi_func, dim, nband, ld_psi, psi_in, eigenvalue_in, david_diag_thr, david_maxiter);
++ntry;
} while (!check_block_conv(ntry, this->notconv, ntry_max, notconv_max));

Expand Down
70 changes: 34 additions & 36 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,50 +38,48 @@ class DiagoDavid : public DiagH<T, Device>
* this function computes the product of the Hamiltonian matrix H and a blockvector X.
*
* Called as follows:
* hpsi(X, HX, nvec, dim, id_start, id_end)
* Result is stored in HX.
* HX = H * X[id_start:id_end]
* hpsi(X, HX, ld, nvec) where X and HX are (ld, nvec)-shaped blockvectors.
* Result HX = H * X is stored in HX.
*
* @param[out] X Head address of input blockvector of type `T*`.
* @param[in] HX Where to write output blockvector of type `T*`.
* @param[in] nvec Number of eigebpairs, i.e. number of vectors in a block.
* @param[in] dim Dimension of matrix.
* @param[in] id_start Start index of blockvector.
* @param[in] id_end End index of blockvector.
* @param[in] HX Head address of output blockvector of type `T*`.
* @param[in] ld Leading dimension of blockvector.
* @param[in] nvec Number of vectors in a block.
*
* @warning HX is the exact address to store output H*X[id_start:id_end];
* @warning while X is the head address of input blockvector, \b without offset.
* @warning Calling function should pass X and HX[offset] as arguments,
* @warning where offset is usually id_start * leading dimension.
* @warning X and HX are the exact address to read input X and store output H*X,
* @warning both of size ld * nvec.
*/
using HPsiFunc = std::function<void(T*, T*, const int, const int, const int, const int)>;
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;

/**
* @brief A function type representing the SX function.
*
* nrow is leading dimension of spsi, npw is leading dimension of psi, nbands is number of vecs
*
* This function type is used to define a matrix-blockvector operator S.
* For generalized eigenvalue problem HX = λSX,
* this function computes the product of the overlap matrix S and a blockvector X.
*
* @param[in] X Pointer to the input array.
* @param[out] SX Pointer to the output array.
* @param[in] nrow Dimension of SX: nbands * nrow.
* @param[in] npw Number of plane waves.
* @param[in] nbands Number of bands.
* @param[in] X Pointer to the input blockvector.
* @param[out] SX Pointer to the output blockvector.
* @param[in] ld_spsi Leading dimension of spsi. Dimension of SX: nbands * nrow.
* @param[in] ld_psi Leading dimension of psi. Number of plane waves.
* @param[in] nbands Number of vectors.
*
* @note called as spsi(in, out, dim, dim, 1)
* @note called like spsi(in, out, dim, dim, 1)
*/
using SPsiFunc = std::function<void(T*, T*, const int, const int, const int)>;

int diag(const HPsiFunc& hpsi_func, // function void hpsi(T*, T*, const int, const int, const int, const int)
const SPsiFunc& spsi_func, // function void spsi(T*, T*, const int, const int, const int)
const int ldPsi, // Leading dimension of the psi input
T *psi_in, // Pointer to eigenvectors
Real* eigenvalue_in, // Pointer to store the resulting eigenvalues
const Real david_diag_thr, // Convergence threshold for the Davidson iteration
const int david_maxiter, // Maximum allowed iterations for the Davidson method
const int ntry_max = 5, // Maximum number of diagonalization attempts (default is 5)
const int notconv_max = 0); // Maximum number of allowed non-converged eigenvectors
int diag(
const HPsiFunc& hpsi_func, // function void hpsi(T*, T*, const int, const int)
const SPsiFunc& spsi_func, // function void spsi(T*, T*, const int, const int, const int)
const int ld_psi, // Leading dimension of the psi input
T *psi_in, // Pointer to eigenvectors
Real* eigenvalue_in, // Pointer to store the resulting eigenvalues
const Real david_diag_thr, // Convergence threshold for the Davidson iteration
const int david_maxiter, // Maximum allowed iterations for the Davidson method
const int ntry_max = 5, // Maximum number of diagonalization attempts (5 by default)
const int notconv_max = 0); // Maximum number of allowed non-converged eigenvectors

private:
bool use_paw = false;
Expand Down Expand Up @@ -130,7 +128,7 @@ class DiagoDavid : public DiagH<T, Device>
const SPsiFunc& spsi_func,
const int dim,
const int nband,
const int ldPsi,
const int ld_psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
Expand Down Expand Up @@ -163,20 +161,20 @@ class DiagoDavid : public DiagH<T, Device>
const int nbase_x,
const Real* eigenvalue,
const T *psi_in,
const int ldPsi,
const int ld_psi,
T* hpsi,
T* spsi,
T* hcc,
T* scc,
T* vcc);

void SchmidtOrth(const int& dim,
const int nband,
const int m,
const T* spsi,
T* lagrange_m,
const int mm_size,
const int mv_size);
const int nband,
const int m,
const T* spsi,
T* lagrange_m,
const int mm_size,
const int mv_size);

void planSchmidtOrth(const int nband, std::vector<int>& pre_matrix_mm_m, std::vector<int>& pre_matrix_mv_m);

Expand Down
Loading

0 comments on commit 72b1d7c

Please sign in to comment.