Skip to content

Commit

Permalink
Refactor: refactor set_diagethr func and remove phsol from esolver [v…
Browse files Browse the repository at this point in the history
…ersion 2] (#5017)

* refactor set_diagethr func

* fix build bug

* remove useless code about set_diagethr

* Remove the inheritance relationship between hsolverPW and hsolver

* format hsolver-sdft code

* remove this->phsol in SpinConstrain

* remove this->phsol in esolver

* [pre-commit.ci lite] apply automatic fixes

* remove useless code abaout phsol

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
haozhihan and pre-commit-ci-lite[bot] authored Aug 31, 2024
1 parent ee3fba3 commit 52f7816
Show file tree
Hide file tree
Showing 26 changed files with 281 additions and 381 deletions.
32 changes: 29 additions & 3 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ ESolver_KS<T, Device>::~ESolver_KS()
delete this->psi;
delete this->pw_wfc;
delete this->p_hamilt;
delete this->phsol;
delete this->p_chgmix;
}

Expand Down Expand Up @@ -379,7 +378,6 @@ void ESolver_KS<T, Device>::hamilt2density(const int istep, const int iter, cons
// LCAO, PW, SDFT and TDDFT.
// After HSolver is constructed, LCAO, PW, SDFT should delete their own
// hamilt2density() and use:
// this->phsol->solve(this->phamilt, this->pes, this->wf, ETHR);
ModuleBase::timer::tick(this->classname, "hamilt2density");
}

Expand Down Expand Up @@ -440,7 +438,35 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
#else
auto iterstart = std::chrono::system_clock::now();
#endif
diag_ethr = this->phsol->set_diagethr(diag_ethr, istep, iter, drho);

if (PARAM.inp.esolver_type == "ksdft")
{
diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type,
PARAM.inp.esolver_type,
PARAM.inp.calculation,
PARAM.inp.init_chg,
GlobalV::precision_flag,
istep,
iter,
drho,
GlobalV::PW_DIAG_THR,
diag_ethr,
GlobalV::nelec);
}
else if (PARAM.inp.esolver_type == "sdft")
{
diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type,
PARAM.inp.esolver_type,
PARAM.inp.calculation,
PARAM.inp.init_chg,
istep,
iter,
drho,
GlobalV::PW_DIAG_THR,
diag_ethr,
GlobalV::NBANDS,
esolver_KS_ne);
}

// 6) initialization of SCF iterations
this->iter_init(istep, iter);
Expand Down
4 changes: 3 additions & 1 deletion source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ class ESolver_KS : public ESolver_FP

std::string basisname; //PW or LCAO

void print_wfcfft(const Input_para& inp, std::ofstream &ofs);
void print_wfcfft(const Input_para& inp, std::ofstream& ofs);

double esolver_KS_ne = 0.0;
};
} // end of namespace
#endif
12 changes: 0 additions & 12 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,6 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
// 9) initialize ppcell
GlobalC::ppcell.init_vloc(GlobalC::ppcell.vloc, this->pw_rho);

// 10) initialize the HSolver
if (this->phsol == nullptr)
{
this->phsol = new hsolver::HSolver<TK>();
}

// 11) inititlize the charge density
this->pelec->charge->allocate(GlobalV::NSPIN);
this->pelec->omega = GlobalC::ucell.omega;
Expand Down Expand Up @@ -702,13 +696,11 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
}

// 3) solve the Hamiltonian and output band gap
if (this->phsol != nullptr)
{
// reset energy
this->pelec->f_en.eband = 0.0;
this->pelec->f_en.demet = 0.0;

// this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER);
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), GlobalV::KS_SOLVER);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER, false);

Expand All @@ -725,10 +717,6 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
}
}
}
else
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW", "HSolver has not been initialed!");
}

// 4) print bands for each k-point and each band
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
Expand Down
21 changes: 6 additions & 15 deletions source/module_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,7 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(const Input_para& inp, UnitCell&
LCAO_domain::divide_HS_in_frag(GlobalV::GAMMA_ONLY_LOCAL, this->pv, kv.get_nks());

// 6) initialize Density Matrix
dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)
->init_DM(&kv, &this->pv, GlobalV::NSPIN);

// 7) initialize Hsolver
if (this->phsol == nullptr)
{
this->phsol = new hsolver::HSolver<std::complex<double>>();
}
dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->init_DM(&kv, &this->pv, GlobalV::NSPIN);

// 8) initialize the charge density
this->pelec->charge->allocate(GlobalV::NSPIN);
Expand Down Expand Up @@ -166,23 +159,21 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density(const int istep, const int iter, cons
kv.get_nks());
this->pelec_td->psiToRho_td(this->psi[0]);
}
else if (this->phsol != nullptr)
else
{
// reset energy
this->pelec->f_en.eband = 0.0;
this->pelec->f_en.demet = 0.0;
if (this->psi != nullptr)
{
// this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec_td, GlobalV::KS_SOLVER);

hsolver::HSolverLCAO<std::complex<double>> hsolver_lcao_obj(&this->pv, GlobalV::KS_SOLVER);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, GlobalV::KS_SOLVER, false);
}
}
else
{
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "HSolver has not been initialed!");
}
// else
// {
// ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "HSolver has not been initialed!");
// }

// print occupation of each band
if (iter == 1 && istep <= 2)
Expand Down
22 changes: 1 addition & 21 deletions source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,10 @@ namespace ModuleESolver
template <typename T>
ESolver_KS_LIP<T>::~ESolver_KS_LIP()
{
// delete HSolver and ElecState
this->deallocate_hsolver();
// delete Hamilt
this->deallocate_hamilt();
}

template <typename T>
void ESolver_KS_LIP<T>::allocate_hsolver()
{
this->phsol = new hsolver::HSolver<T>();
}
template <typename T>
void ESolver_KS_LIP<T>::deallocate_hsolver()
{
if (this->phsol != nullptr)
{
delete (this->phsol);
this->phsol = nullptr;
}
}
template <typename T>
void ESolver_KS_LIP<T>::allocate_hamilt()
{
Expand Down Expand Up @@ -133,7 +117,6 @@ namespace ModuleESolver
ModuleBase::TITLE("ESolver_KS_LIP", "hamilt2density");
ModuleBase::timer::tick("ESolver_KS_LIP", "hamilt2density");

if (this->phsol != nullptr)
{
// reset energy
this->pelec->f_en.eband = 0.0;
Expand Down Expand Up @@ -175,10 +158,7 @@ namespace ModuleESolver
}
}
}
else
{
ModuleBase::WARNING_QUIT("ESolver_KS_LIP", "HSolver has not been allocated.");
}

// add exx
#ifdef __EXX
if (GlobalC::exx_info.info_global.cal_exx) {
Expand Down
3 changes: 1 addition & 2 deletions source/module_esolver/esolver_ks_lcaopw.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ namespace ModuleESolver
protected:
virtual void iter_init(const int istep, const int iter) override;
virtual void iter_finish(int& iter) override;
virtual void allocate_hsolver() override;
virtual void deallocate_hsolver() override;

virtual void allocate_hamilt() override;
virtual void deallocate_hamilt() override;

Expand Down
13 changes: 0 additions & 13 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
template <typename T, typename Device>
ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
{
// delete HSolver and ElecState
this->deallocate_hsolver();

// delete Hamilt
this->deallocate_hamilt();
Expand Down Expand Up @@ -110,12 +108,6 @@ void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCel
// 1) call before_all_runners() of ESolver_KS
ESolver_KS<T, Device>::before_all_runners(inp, ucell);

// 2) initialize HSolver
if (this->phsol == nullptr)
{
this->allocate_hsolver();
}

// 3) initialize ElecState,
if (this->pelec == nullptr)
{
Expand Down Expand Up @@ -348,7 +340,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density(const int istep, const int iter, c
{
ModuleBase::timer::tick("ESolver_KS_PW", "hamilt2density");

if (this->phsol != nullptr)
{
// reset energy
this->pelec->f_en.eband = 0.0;
Expand Down Expand Up @@ -405,10 +396,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density(const int istep, const int iter, c
}
}
}
else
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW", "HSolver has not been initialed!");
}

// calculate the delta_harris energy
// according to new charge density.
Expand Down
2 changes: 0 additions & 2 deletions source/module_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>
// Init Global class
void Init_GlobalC(const Input_para& inp, UnitCell& ucell, pseudopot_cell_vnl& ppcell);

virtual void allocate_hsolver();
virtual void deallocate_hsolver();
virtual void allocate_hamilt();
virtual void deallocate_hamilt();

Expand Down
9 changes: 3 additions & 6 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,6 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell)
ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(std::complex<double>));
}

// 9) initialize the hsolver
// It should be removed after esolver_ks does not use phsol.
this->phsol = new hsolver::HSolverPW_SDFT(&this->kv, this->pw_wfc, &this->wf, this->stowf, this->stoche, this->init_psi);

return;
}

Expand Down Expand Up @@ -192,9 +188,10 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
false);
this->init_psi = true;

// temporary
// temporary
// set_diagethr need it
((hsolver::HSolverPW_SDFT*)phsol)->set_KS_ne(hsolver_pw_sdft_obj.stoiter.KS_ne);
// ((hsolver::HSolverPW_SDFT*)phsol)->set_KS_ne(hsolver_pw_sdft_obj.stoiter.KS_ne);
this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne;

if (GlobalV::MY_STOGROUP == 0)
{
Expand Down
1 change: 0 additions & 1 deletion source/module_esolver/lcao_before_scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ void ESolver_KS_LCAO<TK, TR>::beforesolver(const int istep)
GlobalV::NSPIN,
this->kv,
GlobalV::KS_SOLVER,
this->phsol,
this->p_hamilt,
this->psi,
this->pelec);
Expand Down
13 changes: 2 additions & 11 deletions source/module_esolver/lcao_nscf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,8 @@ void ESolver_KS_LCAO<TK, TR>::nscf() {
// then when the istep is a variable of scf or nscf,
// istep becomes istep-1, this should be fixed in future
int istep = 0;
if (this->phsol != nullptr)
{
// this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER, true);

hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), GlobalV::KS_SOLVER);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER, true);
}
else
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW", "HSolver has not been initialed!");
}
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), GlobalV::KS_SOLVER);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER, true);

time_t time_finish = std::time(nullptr);
ModuleBase::GlobalFunc::OUT_TIME("cal_bands", time_start, time_finish);
Expand Down
Loading

0 comments on commit 52f7816

Please sign in to comment.