Skip to content

Commit

Permalink
add template <Device> for sdft
Browse files Browse the repository at this point in the history
  • Loading branch information
Qianruipku committed Oct 29, 2024
1 parent d7d52c5 commit 3b096c6
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 192 deletions.
65 changes: 35 additions & 30 deletions source/module_elecstate/elecstate_pw_sdft.cpp
Original file line number Diff line number Diff line change
@@ -1,41 +1,46 @@
#include "./elecstate_pw_sdft.h"

#include "module_base/global_function.h"
#include "module_base/global_variable.h"
#include "module_parameter/parameter.h"
#include "module_base/timer.h"
#include "module_base/global_function.h"
#include "module_hamilt_general/module_xc/xc_functional.h"
#include "module_parameter/parameter.h"
namespace elecstate
{
void ElecStatePW_SDFT::psiToRho(const psi::Psi<std::complex<double>>& psi)

template <typename Device>
void ElecStatePW_SDFT<Device>::psiToRho(const psi::Psi<std::complex<double>>& psi)
{
ModuleBase::TITLE(this->classname, "psiToRho");
ModuleBase::timer::tick(this->classname, "psiToRho");
for (int is = 0; is < PARAM.inp.nspin; is++)
{
ModuleBase::TITLE(this->classname, "psiToRho");
ModuleBase::timer::tick(this->classname, "psiToRho");
for(int is=0; is < PARAM.inp.nspin; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
if (XC_Functional::get_func_type() == 3)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
}
}

if(GlobalV::MY_STOGROUP == 0)
{
this->calEBand();
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
if (XC_Functional::get_func_type() == 3)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
}
}

if (GlobalV::MY_STOGROUP == 0)
{
this->calEBand();

for(int is=0; is<PARAM.inp.nspin; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
}
for (int is = 0; is < PARAM.inp.nspin; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
}

for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
this->updateRhoK(psi);
}
this->parallelK();
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
this->updateRhoK(psi);
}
ModuleBase::timer::tick(this->classname, "psiToRho");
return;
this->parallelK();
}
}
ModuleBase::timer::tick(this->classname, "psiToRho");
return;
}

template class ElecStatePW_SDFT<base_device::DEVICE_CPU>;
} // namespace elecstate
36 changes: 19 additions & 17 deletions source/module_elecstate/elecstate_pw_sdft.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@
#include "elecstate_pw.h"
namespace elecstate
{
class ElecStatePW_SDFT : public ElecStatePW<std::complex<double>>
template <typename Device>
class ElecStatePW_SDFT : public ElecStatePW<std::complex<double>, Device>
{
public:
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
Charge* chg_in,
K_Vectors* pkv_in,
UnitCell* ucell_in,
pseudopot_cell_vnl* ppcell_in,
ModulePW::PW_Basis* rhodpw_in,
ModulePW::PW_Basis* rhopw_in,
ModulePW::PW_Basis_Big* bigpw_in)
: ElecStatePW<std::complex<double>,
Device>(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
{
public:
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
Charge* chg_in,
K_Vectors* pkv_in,
UnitCell* ucell_in,
pseudopot_cell_vnl* ppcell_in,
ModulePW::PW_Basis* rhodpw_in,
ModulePW::PW_Basis* rhopw_in,
ModulePW::PW_Basis_Big* bigpw_in)
: ElecStatePW(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
{
this->classname = "ElecStatePW_SDFT";
}
virtual void psiToRho(const psi::Psi<std::complex<double>>& psi) override;
};
}
this->classname = "ElecStatePW_SDFT";
}
virtual void psiToRho(const psi::Psi<std::complex<double>>& psi) override;
};
} // namespace elecstate
#endif
8 changes: 4 additions & 4 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
return new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>();
}
}
else if (esolver_type == "sdft_pw")
{
return new ESolver_SDFT_PW<base_device::DEVICE_CPU>();
}
#ifdef __LCAO
else if (esolver_type == "ksdft_lip")
{
Expand Down Expand Up @@ -230,10 +234,6 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
return p_esolver_lr;
}
#endif
else if (esolver_type == "sdft_pw")
{
return new ESolver_SDFT_PW();
}
else if(esolver_type == "ofdft")
{
return new ESolver_OF();
Expand Down
Loading

0 comments on commit 3b096c6

Please sign in to comment.