diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 8226184e39..c87a51a599 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -86,7 +86,6 @@ ESolver_KS::~ESolver_KS() delete this->psi; delete this->pw_wfc; delete this->p_hamilt; - delete this->phsol; delete this->p_chgmix; } @@ -379,7 +378,6 @@ void ESolver_KS::hamilt2density(const int istep, const int iter, cons // LCAO, PW, SDFT and TDDFT. // After HSolver is constructed, LCAO, PW, SDFT should delete their own // hamilt2density() and use: - // this->phsol->solve(this->phamilt, this->pes, this->wf, ETHR); ModuleBase::timer::tick(this->classname, "hamilt2density"); } @@ -440,7 +438,35 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) #else auto iterstart = std::chrono::system_clock::now(); #endif - diag_ethr = this->phsol->set_diagethr(diag_ethr, istep, iter, drho); + + if (PARAM.inp.esolver_type == "ksdft") + { + diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type, + PARAM.inp.esolver_type, + PARAM.inp.calculation, + PARAM.inp.init_chg, + GlobalV::precision_flag, + istep, + iter, + drho, + GlobalV::PW_DIAG_THR, + diag_ethr, + GlobalV::nelec); + } + else if (PARAM.inp.esolver_type == "sdft") + { + diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type, + PARAM.inp.esolver_type, + PARAM.inp.calculation, + PARAM.inp.init_chg, + istep, + iter, + drho, + GlobalV::PW_DIAG_THR, + diag_ethr, + GlobalV::NBANDS, + esolver_KS_ne); + } // 6) initialization of SCF iterations this->iter_init(istep, iter); diff --git a/source/module_esolver/esolver_ks.h b/source/module_esolver/esolver_ks.h index 36a85eac68..aea19efbf7 100644 --- a/source/module_esolver/esolver_ks.h +++ b/source/module_esolver/esolver_ks.h @@ -119,7 +119,9 @@ class ESolver_KS : public ESolver_FP std::string basisname; //PW or LCAO - void print_wfcfft(const Input_para& inp, std::ofstream &ofs); + void print_wfcfft(const Input_para& inp, std::ofstream& ofs); + + double esolver_KS_ne = 0.0; }; } // end of namespace #endif diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index 177b791b98..bdc03ed3e3 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -217,12 +217,6 @@ void ESolver_KS_LCAO::before_all_runners(const Input_para& inp, UnitCell // 9) initialize ppcell GlobalC::ppcell.init_vloc(GlobalC::ppcell.vloc, this->pw_rho); - // 10) initialize the HSolver - if (this->phsol == nullptr) - { - this->phsol = new hsolver::HSolver(); - } - // 11) inititlize the charge density this->pelec->charge->allocate(GlobalV::NSPIN); this->pelec->omega = GlobalC::ucell.omega; @@ -702,13 +696,11 @@ void ESolver_KS_LCAO::hamilt2density(int istep, int iter, double ethr) } // 3) solve the Hamiltonian and output band gap - if (this->phsol != nullptr) { // reset energy this->pelec->f_en.eband = 0.0; this->pelec->f_en.demet = 0.0; - // this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER); hsolver::HSolverLCAO hsolver_lcao_obj(&(this->pv), GlobalV::KS_SOLVER); hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER, false); @@ -725,10 +717,6 @@ void ESolver_KS_LCAO::hamilt2density(int istep, int iter, double ethr) } } } - else - { - ModuleBase::WARNING_QUIT("ESolver_KS_PW", "HSolver has not been initialed!"); - } // 4) print bands for each k-point and each band for (int ik = 0; ik < this->kv.get_nks(); ++ik) diff --git a/source/module_esolver/esolver_ks_lcao_tddft.cpp b/source/module_esolver/esolver_ks_lcao_tddft.cpp index 3feb76bf9f..9ac24357e9 100644 --- a/source/module_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/module_esolver/esolver_ks_lcao_tddft.cpp @@ -98,14 +98,7 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(const Input_para& inp, UnitCell& LCAO_domain::divide_HS_in_frag(GlobalV::GAMMA_ONLY_LOCAL, this->pv, kv.get_nks()); // 6) initialize Density Matrix - dynamic_cast>*>(this->pelec) - ->init_DM(&kv, &this->pv, GlobalV::NSPIN); - - // 7) initialize Hsolver - if (this->phsol == nullptr) - { - this->phsol = new hsolver::HSolver>(); - } + dynamic_cast>*>(this->pelec)->init_DM(&kv, &this->pv, GlobalV::NSPIN); // 8) initialize the charge density this->pelec->charge->allocate(GlobalV::NSPIN); @@ -166,23 +159,21 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density(const int istep, const int iter, cons kv.get_nks()); this->pelec_td->psiToRho_td(this->psi[0]); } - else if (this->phsol != nullptr) + else { // reset energy this->pelec->f_en.eband = 0.0; this->pelec->f_en.demet = 0.0; if (this->psi != nullptr) { - // this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec_td, GlobalV::KS_SOLVER); - hsolver::HSolverLCAO> hsolver_lcao_obj(&this->pv, GlobalV::KS_SOLVER); hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, GlobalV::KS_SOLVER, false); } } - else - { - ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "HSolver has not been initialed!"); - } + // else + // { + // ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "HSolver has not been initialed!"); + // } // print occupation of each band if (iter == 1 && istep <= 2) diff --git a/source/module_esolver/esolver_ks_lcaopw.cpp b/source/module_esolver/esolver_ks_lcaopw.cpp index d34ce8e607..2dd0d4aaae 100644 --- a/source/module_esolver/esolver_ks_lcaopw.cpp +++ b/source/module_esolver/esolver_ks_lcaopw.cpp @@ -58,26 +58,10 @@ namespace ModuleESolver template ESolver_KS_LIP::~ESolver_KS_LIP() { - // delete HSolver and ElecState - this->deallocate_hsolver(); // delete Hamilt this->deallocate_hamilt(); } - template - void ESolver_KS_LIP::allocate_hsolver() - { - this->phsol = new hsolver::HSolver(); - } - template - void ESolver_KS_LIP::deallocate_hsolver() - { - if (this->phsol != nullptr) - { - delete (this->phsol); - this->phsol = nullptr; - } - } template void ESolver_KS_LIP::allocate_hamilt() { @@ -133,7 +117,6 @@ namespace ModuleESolver ModuleBase::TITLE("ESolver_KS_LIP", "hamilt2density"); ModuleBase::timer::tick("ESolver_KS_LIP", "hamilt2density"); - if (this->phsol != nullptr) { // reset energy this->pelec->f_en.eband = 0.0; @@ -175,10 +158,7 @@ namespace ModuleESolver } } } - else - { - ModuleBase::WARNING_QUIT("ESolver_KS_LIP", "HSolver has not been allocated."); - } + // add exx #ifdef __EXX if (GlobalC::exx_info.info_global.cal_exx) { diff --git a/source/module_esolver/esolver_ks_lcaopw.h b/source/module_esolver/esolver_ks_lcaopw.h index 43c4c2167e..e5202aff46 100644 --- a/source/module_esolver/esolver_ks_lcaopw.h +++ b/source/module_esolver/esolver_ks_lcaopw.h @@ -29,8 +29,7 @@ namespace ModuleESolver protected: virtual void iter_init(const int istep, const int iter) override; virtual void iter_finish(int& iter) override; - virtual void allocate_hsolver() override; - virtual void deallocate_hsolver() override; + virtual void allocate_hamilt() override; virtual void deallocate_hamilt() override; diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index aa7d856b35..b9eff6a08e 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -72,8 +72,6 @@ ESolver_KS_PW::ESolver_KS_PW() template ESolver_KS_PW::~ESolver_KS_PW() { - // delete HSolver and ElecState - this->deallocate_hsolver(); // delete Hamilt this->deallocate_hamilt(); @@ -110,12 +108,6 @@ void ESolver_KS_PW::before_all_runners(const Input_para& inp, UnitCel // 1) call before_all_runners() of ESolver_KS ESolver_KS::before_all_runners(inp, ucell); - // 2) initialize HSolver - if (this->phsol == nullptr) - { - this->allocate_hsolver(); - } - // 3) initialize ElecState, if (this->pelec == nullptr) { @@ -348,7 +340,6 @@ void ESolver_KS_PW::hamilt2density(const int istep, const int iter, c { ModuleBase::timer::tick("ESolver_KS_PW", "hamilt2density"); - if (this->phsol != nullptr) { // reset energy this->pelec->f_en.eband = 0.0; @@ -405,10 +396,6 @@ void ESolver_KS_PW::hamilt2density(const int istep, const int iter, c } } } - else - { - ModuleBase::WARNING_QUIT("ESolver_KS_PW", "HSolver has not been initialed!"); - } // calculate the delta_harris energy // according to new charge density. diff --git a/source/module_esolver/esolver_ks_pw.h b/source/module_esolver/esolver_ks_pw.h index ec48a49347..f2e65a4ef3 100644 --- a/source/module_esolver/esolver_ks_pw.h +++ b/source/module_esolver/esolver_ks_pw.h @@ -56,8 +56,6 @@ class ESolver_KS_PW : public ESolver_KS // Init Global class void Init_GlobalC(const Input_para& inp, UnitCell& ucell, pseudopot_cell_vnl& ppcell); - virtual void allocate_hsolver(); - virtual void deallocate_hsolver(); virtual void allocate_hamilt(); virtual void deallocate_hamilt(); diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index c504636771..2220d3b0cb 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -126,10 +126,6 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell) ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(std::complex)); } - // 9) initialize the hsolver - // It should be removed after esolver_ks does not use phsol. - this->phsol = new hsolver::HSolverPW_SDFT(&this->kv, this->pw_wfc, &this->wf, this->stowf, this->stoche, this->init_psi); - return; } @@ -192,9 +188,10 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) false); this->init_psi = true; - // temporary + // temporary // set_diagethr need it - ((hsolver::HSolverPW_SDFT*)phsol)->set_KS_ne(hsolver_pw_sdft_obj.stoiter.KS_ne); + // ((hsolver::HSolverPW_SDFT*)phsol)->set_KS_ne(hsolver_pw_sdft_obj.stoiter.KS_ne); + this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne; if (GlobalV::MY_STOGROUP == 0) { diff --git a/source/module_esolver/lcao_before_scf.cpp b/source/module_esolver/lcao_before_scf.cpp index ab34e79c02..10c6a71a92 100644 --- a/source/module_esolver/lcao_before_scf.cpp +++ b/source/module_esolver/lcao_before_scf.cpp @@ -155,7 +155,6 @@ void ESolver_KS_LCAO::beforesolver(const int istep) GlobalV::NSPIN, this->kv, GlobalV::KS_SOLVER, - this->phsol, this->p_hamilt, this->psi, this->pelec); diff --git a/source/module_esolver/lcao_nscf.cpp b/source/module_esolver/lcao_nscf.cpp index d1da8bc3b3..249311a66b 100644 --- a/source/module_esolver/lcao_nscf.cpp +++ b/source/module_esolver/lcao_nscf.cpp @@ -52,17 +52,8 @@ void ESolver_KS_LCAO::nscf() { // then when the istep is a variable of scf or nscf, // istep becomes istep-1, this should be fixed in future int istep = 0; - if (this->phsol != nullptr) - { - // this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER, true); - - hsolver::HSolverLCAO hsolver_lcao_obj(&(this->pv), GlobalV::KS_SOLVER); - hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER, true); - } - else - { - ModuleBase::WARNING_QUIT("ESolver_KS_PW", "HSolver has not been initialed!"); - } + hsolver::HSolverLCAO hsolver_lcao_obj(&(this->pv), GlobalV::KS_SOLVER); + hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, GlobalV::KS_SOLVER, true); time_t time_finish = std::time(nullptr); ModuleBase::GlobalFunc::OUT_TIME("cal_bands", time_start, time_finish); diff --git a/source/module_esolver/pw_fun.cpp b/source/module_esolver/pw_fun.cpp index a4060208bf..c24d426f46 100644 --- a/source/module_esolver/pw_fun.cpp +++ b/source/module_esolver/pw_fun.cpp @@ -49,20 +49,6 @@ namespace ModuleESolver { -template -void ESolver_KS_PW::allocate_hsolver() -{ - this->phsol = new hsolver::HSolverPW(this->pw_wfc, &this->wf, false); -} -template -void ESolver_KS_PW::deallocate_hsolver() -{ - if (this->phsol != nullptr) - { - delete reinterpret_cast*>(this->phsol); - this->phsol = nullptr; - } -} template void ESolver_KS_PW::allocate_hamilt() { @@ -80,46 +66,41 @@ void ESolver_KS_PW::deallocate_hamilt() template -void ESolver_KS_PW::hamilt2estates(const double ethr) { - if (this->phsol != nullptr) { - hsolver::DiagoIterAssist::need_subspace = false; - hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; +void ESolver_KS_PW::hamilt2estates(const double ethr) +{ + hsolver::DiagoIterAssist::need_subspace = false; + hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; - std::vector is_occupied(this->kspw_psi->get_nk() * this->kspw_psi->get_nbands(), true); + std::vector is_occupied(this->kspw_psi->get_nk() * this->kspw_psi->get_nbands(), true); - elecstate::set_is_occupied(is_occupied, - this->pelec, - hsolver::DiagoIterAssist::SCF_ITER, - this->kspw_psi->get_nk(), - this->kspw_psi->get_nbands(), - PARAM.inp.diago_full_acc); + elecstate::set_is_occupied(is_occupied, + this->pelec, + hsolver::DiagoIterAssist::SCF_ITER, + this->kspw_psi->get_nk(), + this->kspw_psi->get_nbands(), + PARAM.inp.diago_full_acc); - hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi); + hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi); - hsolver_pw_obj.solve(this->p_hamilt, - this->kspw_psi[0], - this->pelec, - this->pelec->ekb.c, - is_occupied, - PARAM.inp.ks_solver, - PARAM.inp.calculation, - PARAM.inp.basis_type, - PARAM.inp.use_paw, - GlobalV::use_uspp, - GlobalV::RANK_IN_POOL, - GlobalV::NPROC_IN_POOL, - hsolver::DiagoIterAssist::SCF_ITER, - hsolver::DiagoIterAssist::need_subspace, - hsolver::DiagoIterAssist::PW_DIAG_NMAX, - hsolver::DiagoIterAssist::PW_DIAG_THR, - true); + hsolver_pw_obj.solve(this->p_hamilt, + this->kspw_psi[0], + this->pelec, + this->pelec->ekb.c, + is_occupied, + PARAM.inp.ks_solver, + PARAM.inp.calculation, + PARAM.inp.basis_type, + PARAM.inp.use_paw, + GlobalV::use_uspp, + GlobalV::RANK_IN_POOL, + GlobalV::NPROC_IN_POOL, + hsolver::DiagoIterAssist::SCF_ITER, + hsolver::DiagoIterAssist::need_subspace, + hsolver::DiagoIterAssist::PW_DIAG_NMAX, + hsolver::DiagoIterAssist::PW_DIAG_THR, + true); - this->init_psi = true; - - } else { - ModuleBase::WARNING_QUIT("ESolver_KS_PW", - "HSolver has not been initialed!"); - } + this->init_psi = true; } template class ESolver_KS_PW, base_device::DEVICE_CPU>; diff --git a/source/module_esolver/pw_init_after_vc.cpp b/source/module_esolver/pw_init_after_vc.cpp index d43aead9df..780badc596 100644 --- a/source/module_esolver/pw_init_after_vc.cpp +++ b/source/module_esolver/pw_init_after_vc.cpp @@ -90,10 +90,7 @@ void ESolver_KS_PW::init_after_vc(const Input_para& inp, UnitCell& uc this->pw_wfc->collect_local_pw(inp.erf_ecut, inp.erf_height, inp.erf_sigma); - - delete this->phsol; this->init_psi = false; - this->allocate_hsolver(); delete this->pelec; this->pelec diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index 9b060e59fb..4c5d5ab535 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -11,8 +11,6 @@ void SpinConstrain, base_device::DEVICE_CPU>::cal_mw_from_l ModuleBase::TITLE("SpinConstrain","cal_mw_from_lambda"); ModuleBase::timer::tick("SpinConstrain", "cal_mw_from_lambda"); - // this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec, this->KS_SOLVER, true); - // diagonalization without update charge hsolver::HSolverLCAO> hsolver_lcao_obj(this->ParaV, this->KS_SOLVER); hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, this->KS_SOLVER, true); diff --git a/source/module_hamilt_lcao/module_deltaspin/init_sc.cpp b/source/module_hamilt_lcao/module_deltaspin/init_sc.cpp index b826a2d425..d0aad9caf8 100644 --- a/source/module_hamilt_lcao/module_deltaspin/init_sc.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/init_sc.cpp @@ -15,7 +15,6 @@ void SpinConstrain::init_sc(double sc_thr_in, int nspin_in, K_Vectors& kv_in, std::string KS_SOLVER_in, - hsolver::HSolver* phsol_in, hamilt::Hamilt* p_hamilt_in, psi::Psi* psi_in, elecstate::ElecState* pelec_in) @@ -28,7 +27,7 @@ void SpinConstrain::init_sc(double sc_thr_in, this->bcast_ScData(sc_file, this->get_nat(), this->get_ntype()); this->set_npol(NPOL); this->set_ParaV(ParaV_in); - this->set_solver_parameters(kv_in, phsol_in, p_hamilt_in, psi_in, pelec_in, KS_SOLVER_in); + this->set_solver_parameters(kv_in, p_hamilt_in, psi_in, pelec_in, KS_SOLVER_in); } template class SpinConstrain, base_device::DEVICE_CPU>; diff --git a/source/module_hamilt_lcao/module_deltaspin/spin_constrain.cpp b/source/module_hamilt_lcao/module_deltaspin/spin_constrain.cpp index 9627f71885..a984d1dc83 100644 --- a/source/module_hamilt_lcao/module_deltaspin/spin_constrain.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/spin_constrain.cpp @@ -257,12 +257,15 @@ void SpinConstrain::set_target_mag() mag.x = element_data.target_mag_val * std::sin(radian_angle1) * std::cos(radian_angle2); mag.y = element_data.target_mag_val * std::sin(radian_angle1) * std::sin(radian_angle2); mag.z = element_data.target_mag_val * std::cos(radian_angle1); - if (std::abs(mag.x) < 1e-14) + if (std::abs(mag.x) < 1e-14) { mag.x = 0.0; - if (std::abs(mag.y) < 1e-14) +} + if (std::abs(mag.y) < 1e-14) { mag.y = 0.0; - if (std::abs(mag.z) < 1e-14) +} + if (std::abs(mag.z) < 1e-14) { mag.z = 0.0; +} } this->target_mag_[iat] = mag; } @@ -506,14 +509,12 @@ bool SpinConstrain::get_decay_grad_switch() template void SpinConstrain::set_solver_parameters(K_Vectors& kv_in, - hsolver::HSolver* phsol_in, hamilt::Hamilt* p_hamilt_in, psi::Psi* psi_in, elecstate::ElecState* pelec_in, std::string KS_SOLVER_in) { this->kv_ = kv_in; - this->phsol = phsol_in; this->p_hamilt = p_hamilt_in; this->psi = psi_in; this->pelec = pelec_in; diff --git a/source/module_hamilt_lcao/module_deltaspin/spin_constrain.h b/source/module_hamilt_lcao/module_deltaspin/spin_constrain.h index 78aa14099f..1b2c18c570 100644 --- a/source/module_hamilt_lcao/module_deltaspin/spin_constrain.h +++ b/source/module_hamilt_lcao/module_deltaspin/spin_constrain.h @@ -37,7 +37,6 @@ class SpinConstrain int nspin_in, K_Vectors& kv_in, std::string KS_SOLVER_in, - hsolver::HSolver* phsol_in, hamilt::Hamilt* p_hamilt_in, psi::Psi* psi_in, elecstate::ElecState* pelec_in); @@ -203,7 +202,6 @@ class SpinConstrain void set_ParaV(Parallel_Orbitals* ParaV_in); /// @brief set parameters for solver void set_solver_parameters(K_Vectors& kv_in, - hsolver::HSolver* phsol_in, hamilt::Hamilt* p_hamilt_in, psi::Psi* psi_in, elecstate::ElecState* pelec_in, diff --git a/source/module_hamilt_lcao/module_deltaspin/test/init_sc_test.cpp b/source/module_hamilt_lcao/module_deltaspin/test/init_sc_test.cpp index 2e28efd2d2..f4adec2308 100644 --- a/source/module_hamilt_lcao/module_deltaspin/test/init_sc_test.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/test/init_sc_test.cpp @@ -88,7 +88,6 @@ TYPED_TEST(SpinConstrainTest, InitSc) KS_SOLVER, nullptr, nullptr, - nullptr, nullptr); EXPECT_EQ(this->sc.get_nat(), 6); EXPECT_EQ(this->sc.get_ntype(), 2); diff --git a/source/module_hamilt_lcao/module_deltaspin/test/spin_constrain_test.cpp b/source/module_hamilt_lcao/module_deltaspin/test/spin_constrain_test.cpp index 989a11bc86..390d8cd4cb 100644 --- a/source/module_hamilt_lcao/module_deltaspin/test/spin_constrain_test.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/test/spin_constrain_test.cpp @@ -308,12 +308,15 @@ TYPED_TEST(SpinConstrainTest, SetTargetMagType1) double mag_y = sc_data.target_mag_val * std::sin(sc_data.target_mag_angle1 * M_PI / 180) * std::sin(sc_data.target_mag_angle2 * M_PI / 180); double mag_z = sc_data.target_mag_val * std::cos(sc_data.target_mag_angle1 * M_PI / 180); - if (std::abs(mag_x) < 1e-14) + if (std::abs(mag_x) < 1e-14) { mag_x = 0.0; - if (std::abs(mag_y) < 1e-14) +} + if (std::abs(mag_y) < 1e-14) { mag_y = 0.0; - if (std::abs(mag_z) < 1e-14) +} + if (std::abs(mag_z) < 1e-14) { mag_z = 0.0; +} EXPECT_DOUBLE_EQ(mag_x, target_mag[iat].x); EXPECT_DOUBLE_EQ(mag_y, target_mag[iat].y); EXPECT_DOUBLE_EQ(mag_z, target_mag[iat].z); @@ -349,7 +352,7 @@ TYPED_TEST(SpinConstrainTest, SetInputParameters) int nsc_min = 2; double alpha_trial = 0.01; double sccut = 3.0; - bool decay_grad_switch = 1; + bool decay_grad_switch = true; this->sc.set_input_parameters(sc_thr, nsc, nsc_min, alpha_trial, sccut, decay_grad_switch); EXPECT_DOUBLE_EQ(this->sc.get_sc_thr(), sc_thr); EXPECT_EQ(this->sc.get_nsc(), nsc); @@ -363,7 +366,7 @@ TYPED_TEST(SpinConstrainTest, SetSolverParameters) { K_Vectors kv; this->sc.set_nspin(4); - this->sc.set_solver_parameters(kv, nullptr, nullptr, nullptr, nullptr, "genelpa"); + this->sc.set_solver_parameters(kv, nullptr, nullptr, nullptr, "genelpa"); EXPECT_EQ(this->sc.get_nspin(), 4); EXPECT_EQ(this->sc.phsol, nullptr); EXPECT_EQ(this->sc.p_hamilt, nullptr); diff --git a/source/module_hsolver/hsolver.cpp b/source/module_hsolver/hsolver.cpp index 5b8d9b73e7..d5d1289836 100644 --- a/source/module_hsolver/hsolver.cpp +++ b/source/module_hsolver/hsolver.cpp @@ -3,6 +3,132 @@ namespace hsolver { +double set_diagethr_ks(const std::string basis_type, + const std::string esolver_type, + const std::string calculation_in, + const std::string init_chg_in, + const std::string precision_flag_in, + const int istep, + const int iter, + const double drho, + const double pw_diag_thr_init, + const double diag_ethr_in, + const double nelec_in) +{ + double res_diag_ethr = diag_ethr_in; + + if (basis_type == "pw" && esolver_type == "ksdft") + { + // It is too complex now and should be modified. + if (iter == 1) + { + if (std::abs(res_diag_ethr - 1.0e-2) < 1.0e-6) + { + if (init_chg_in == "file") + { + //====================================================== + // if you think that the starting potential is good + // do not spoil it with a louly first diagonalization: + // set a strict diag ethr in the input file + // ()diago_the_init + //====================================================== + res_diag_ethr = 1.0e-5; + } + else + { + //======================================================= + // starting atomic potential is probably far from scf + // don't waste iterations in the first diagonalization + //======================================================= + res_diag_ethr = 1.0e-2; + } + } + + if (calculation_in == "md" || calculation_in == "relax" || calculation_in == "cell-relax") + { + res_diag_ethr = std::max(res_diag_ethr, static_cast(pw_diag_thr_init)); + } + } + else + { + if (iter == 2) + { + res_diag_ethr = 1.e-2; + } + res_diag_ethr = std::min(res_diag_ethr, + static_cast(0.1) * drho + / std::max(static_cast(1.0), static_cast(nelec_in))); + } + + // It is essential for single precision implementation to keep the diag ethr + // value less or equal to the single-precision limit of convergence(0.5e-4). + // modified by denghuilu at 2023-05-15 + if (precision_flag_in == "single") + { + res_diag_ethr = std::max(res_diag_ethr, static_cast(0.5e-4)); + } + } + else + { + res_diag_ethr = 0.0; + } + + return res_diag_ethr; +} + + +double set_diagethr_sdft(const std::string basis_type, + const std::string esolver_type, + const std::string calculation_in, + const std::string init_chg_in, + const int istep, + const int iter, + const double drho, + const double pw_diag_thr_init, + const double diag_ethr_in, + const int nband_in, + const double stoiter_ks_ne_in) +{ + double res_diag_ethr = diag_ethr_in; + + if (basis_type == "pw" && esolver_type == "sdft") + { + if (iter == 1) + { + if (istep == 0) + { + if (init_chg_in == "file") + { + res_diag_ethr = 1.0e-5; + } + res_diag_ethr = std::max(res_diag_ethr, pw_diag_thr_init); + } + else + { + res_diag_ethr = std::max(res_diag_ethr, 1.0e-5); + } + } + else + { + if (nband_in > 0 && stoiter_ks_ne_in > 1e-6) //GlobalV::NBANDS > 0 && this->stoiter.KS_ne > 1e-6 + { + res_diag_ethr = std::min(res_diag_ethr, 0.1 * drho / std::max(1.0, stoiter_ks_ne_in)); + } + else + { + res_diag_ethr = 0.0; + } + } + } + else + { + res_diag_ethr = 0.0; + } + + return res_diag_ethr; +} + + double reset_diag_ethr(std::ofstream& ofs_running, const std::string basis_type, const std::string esolver_type, @@ -14,7 +140,7 @@ double reset_diag_ethr(std::ofstream& ofs_running, { double new_diag_ethr = 0.0; - + if (basis_type == "pw" && esolver_type == "ksdft") { ofs_running << " Notice: Threshold on eigenvalues was too large.\n"; @@ -58,100 +184,4 @@ double cal_hsolve_error(const std::string basis_type, } }; - -// double set_diagethr(double diag_ethr_in, -// const int istep, -// const int iter, -// const double drho, -// std::string basis_type, -// std::string esolver_type) -// { -// if (basis_type = "pw" && esolver_type = "ksdft") -// { -// // It is too complex now and should be modified. -// if (iter == 1) -// { -// if (std::abs(diag_ethr_in - 1.0e-2) < 1.0e-6) -// { -// if (GlobalV::init_chg == "file") -// { -// //====================================================== -// // if you think that the starting potential is good -// // do not spoil it with a louly first diagonalization: -// // set a strict diag ethr in the input file -// // ()diago_the_init -// //====================================================== -// diag_ethr_in = 1.0e-5; -// } -// else -// { -// //======================================================= -// // starting atomic potential is probably far from scf -// // don't waste iterations in the first diagonalization -// //======================================================= -// diag_ethr_in = 1.0e-2; -// } -// } - -// if (GlobalV::CALCULATION == "md" || GlobalV::CALCULATION == "relax" || GlobalV::CALCULATION == "cell-relax") -// { -// diag_ethr_in = std::max(diag_ethr_in, static_cast(GlobalV::PW_DIAG_THR)); -// } -// } -// else -// { -// if (iter == 2) -// { -// diag_ethr_in = 1.e-2; -// } -// diag_ethr_in = std::min(diag_ethr_in, -// static_cast(0.1) * drho -// / std::max(static_cast(1.0), static_cast(GlobalV::nelec))); -// } -// // It is essential for single precision implementation to keep the diag ethr -// // value less or equal to the single-precision limit of convergence(0.5e-4). -// // modified by denghuilu at 2023-05-15 -// if (GlobalV::precision_flag == "single") -// { -// diag_ethr_in = std::max(diag_ethr_in, static_cast(0.5e-4)); -// } -// } -// else if (basis_type = "pw" && esolver_type = "sdft") -// { -// if (iter == 1) -// { -// if (istep == 0) -// { -// if (GlobalV::init_chg == "file") -// { -// diag_ethr_in = 1.0e-5; -// } -// diag_ethr_in = std::max(diag_ethr_in, GlobalV::PW_DIAG_THR); -// } -// else -// { -// diag_ethr_in = std::max(diag_ethr_in, 1.0e-5); -// } -// } -// else -// { -// if (GlobalV::NBANDS > 0 && this->stoiter.KS_ne > 1e-6) -// { -// diag_ethr_in = std::min(diag_ethr_in, 0.1 * drho / std::max(1.0, this->stoiter.KS_ne)); -// } -// else -// { -// diag_ethr_in = 0.0; -// } -// } -// } -// else -// { -// diag_ethr_in = 0.0; -// } - -// return 0.0; -// }; - - } // namespace hsolver \ No newline at end of file diff --git a/source/module_hsolver/hsolver.h b/source/module_hsolver/hsolver.h index bdf282beed..ecf806289d 100644 --- a/source/module_hsolver/hsolver.h +++ b/source/module_hsolver/hsolver.h @@ -22,14 +22,34 @@ class HSolver public: HSolver() {}; - - // set diagethr according to drho (for lcao and lcao-in-pw, we suppose the error is zero and we set diagethr to 0) - virtual Real set_diagethr(Real diag_ethr_in, const int istep, const int iter, const Real drho) - { - return 0.0; - } }; + +double set_diagethr_ks(const std::string basis_type, + const std::string esolver_type, + const std::string calculation_in, + const std::string init_chg_in, + const std::string precision_flag_in, + const int istep, + const int iter, + const double drho, + const double pw_diag_thr_init, + const double diag_ethr_in, + const double nelec_in); + +double set_diagethr_sdft(const std::string basis_type, + const std::string esolver_type, + const std::string calculation_in, + const std::string init_chg_in, + const int istep, + const int iter, + const double drho, + const double pw_diag_thr_init, + const double diag_ethr_in, + const int nband_in, + const double stoiter_ks_ne_in); + + // reset diagethr according to drho and hsolver_error double reset_diag_ethr(std::ofstream& ofs_running, const std::string basis_type, diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 1751afec5f..51bd210366 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -638,62 +638,6 @@ void HSolverPW::output_iterInfo() DiagoIterAssist::avg_iter = 0.0; } } -template -typename HSolverPW::Real HSolverPW::set_diagethr(Real diag_ethr_in, - const int istep, - const int iter, - const Real drho) -{ - // It is too complex now and should be modified. - if (iter == 1) - { - if (std::abs(diag_ethr_in - 1.0e-2) < 1.0e-6) - { - if (GlobalV::init_chg == "file") - { - //====================================================== - // if you think that the starting potential is good - // do not spoil it with a louly first diagonalization: - // set a strict diag ethr in the input file - // ()diago_the_init - //====================================================== - diag_ethr_in = 1.0e-5; - } - else - { - //======================================================= - // starting atomic potential is probably far from scf - // don't waste iterations in the first diagonalization - //======================================================= - diag_ethr_in = 1.0e-2; - } - } - - if (GlobalV::CALCULATION == "md" || GlobalV::CALCULATION == "relax" || GlobalV::CALCULATION == "cell-relax") - { - diag_ethr_in = std::max(diag_ethr_in, static_cast(GlobalV::PW_DIAG_THR)); - } - } - else - { - if (iter == 2) - { - diag_ethr_in = 1.e-2; - } - diag_ethr_in = std::min(diag_ethr_in, - static_cast(0.1) * drho - / std::max(static_cast(1.0), static_cast(GlobalV::nelec))); - } - // It is essential for single precision implementation to keep the diag ethr - // value less or equal to the single-precision limit of convergence(0.5e-4). - // modified by denghuilu at 2023-05-15 - if (GlobalV::precision_flag == "single") - { - diag_ethr_in = std::max(diag_ethr_in, static_cast(0.5e-4)); - } - - return diag_ethr_in; -} template class HSolverPW, base_device::DEVICE_CPU>; template class HSolverPW, base_device::DEVICE_CPU>; diff --git a/source/module_hsolver/hsolver_pw.h b/source/module_hsolver/hsolver_pw.h index 5c721a5775..43fcc3b001 100644 --- a/source/module_hsolver/hsolver_pw.h +++ b/source/module_hsolver/hsolver_pw.h @@ -10,7 +10,7 @@ namespace hsolver { template -class HSolverPW : public HSolver +class HSolverPW { private: // Note GetTypeReal::type will @@ -46,8 +46,6 @@ class HSolverPW : public HSolver const int diag_iter_max_in, const double iter_diag_thr_in, const bool skip_charge); - - virtual Real set_diagethr(Real diag_ethr_in, const int istep, const int iter, const Real drho) override; protected: // diago caller @@ -98,6 +96,7 @@ class HSolverPW : public HSolver void paw_func_after_kloop(psi::Psi& psi, elecstate::ElecState* pes); #endif + }; diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 398ff1283d..83d83fc0f5 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -7,7 +7,9 @@ #include -namespace hsolver { +namespace hsolver +{ + void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, psi::Psi>& psi, elecstate::ElecState* pes, @@ -42,18 +44,18 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, // select the method of diagonalization this->method = method_in; // report if the specified diagonalization method is not supported - const std::initializer_list _methods - = {"cg", "dav", "dav_subspace", "bpcg"}; - if (std::find(std::begin(_methods), std::end(_methods), this->method) - == std::end(_methods)) { - ModuleBase::WARNING_QUIT("HSolverPW::solve", - "This method of DiagH is not supported!"); + const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg"}; + if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods)) + { + ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!"); } // part of KSDFT to get KS orbitals - for (int ik = 0; ik < nks; ++ik) { + for (int ik = 0; ik < nks; ++ik) + { pHamilt->updateHk(ik); - if (nbands > 0 && GlobalV::MY_STOGROUP == 0) { + if (nbands > 0 && GlobalV::MY_STOGROUP == 0) + { this->updatePsiK(pHamilt, psi, ik); // template add precondition calculating here update_precondition(precondition, ik, this->wfc_basis->npwk[ik]); @@ -65,12 +67,9 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, stoiter.stohchi.current_ik = ik; #ifdef __MPI - if (nbands > 0) { - MPI_Bcast(&psi(ik, 0, 0), - npwx * nbands, - MPI_DOUBLE_COMPLEX, - 0, - PARAPW_WORLD); + if (nbands > 0) + { + MPI_Bcast(&psi(ik, 0, 0), npwx * nbands, MPI_DOUBLE_COMPLEX, 0, PARAPW_WORLD); MPI_Bcast(&(pes->ekb(ik, 0)), nbands, MPI_DOUBLE, 0, PARAPW_WORLD); } #endif @@ -86,32 +85,38 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, this->initialed_psi = true; } - for (int ik = 0; ik < nks; ik++) { + for (int ik = 0; ik < nks; ik++) + { // init k - if (nks > 1) { + if (nks > 1) + { pHamilt->updateHk(ik); -} + } stoiter.stohchi.current_ik = ik; stoiter.calPn(ik, stowf); } stoiter.itermu(iter, pes); stoiter.calHsqrtchi(stowf); - if (skip_charge) { + if (skip_charge) + { ModuleBase::timer::tick("HSolverPW_SDFT", "solve"); return; } //(5) calculate new charge density // calculate KS rho. - if (nbands > 0) { + if (nbands > 0) + { pes->psiToRho(psi); #ifdef __MPI MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD); #endif - } else { - for (int is = 0; is < GlobalV::NSPIN; is++) { - ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[is], - pes->charge->nrxx); + } + else + { + for (int is = 0; is < GlobalV::NSPIN; is++) + { + ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[is], pes->charge->nrxx); } } // calculate stochastic rho @@ -122,29 +127,4 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, return; } -double HSolverPW_SDFT::set_diagethr(double diag_ethr_in, - const int istep, - const int iter, - const double drho) { - if (iter == 1) { - if (istep == 0) { - if (GlobalV::init_chg == "file") { - diag_ethr_in = 1.0e-5; - } - diag_ethr_in = std::max(diag_ethr_in, GlobalV::PW_DIAG_THR); - } else { - diag_ethr_in = std::max(diag_ethr_in, 1.0e-5); -} - } else { - if (GlobalV::NBANDS > 0 && this->stoiter.KS_ne > 1e-6) { - diag_ethr_in - = std::min(diag_ethr_in, - 0.1 * drho / std::max(1.0, this->stoiter.KS_ne)); - } else { - diag_ethr_in = 0.0; -} - } - - return diag_ethr_in; -} } // namespace hsolver \ No newline at end of file diff --git a/source/module_hsolver/hsolver_pw_sdft.h b/source/module_hsolver/hsolver_pw_sdft.h index 610805ed55..14a45627d1 100644 --- a/source/module_hsolver/hsolver_pw_sdft.h +++ b/source/module_hsolver/hsolver_pw_sdft.h @@ -32,13 +32,6 @@ class HSolverPW_SDFT : public HSolverPW> const double pw_diag_thr_in, const bool skip_charge); - virtual double set_diagethr(double diag_ethr_in, const int istep, const int iter, const double drho) override; - - void set_KS_ne(const double& KS_ne_in) - { - stoiter.KS_ne = KS_ne_in; - } - Stochastic_Iter stoiter; }; } // namespace hsolver diff --git a/source/module_hsolver/test/test_hsolver.cpp b/source/module_hsolver/test/test_hsolver.cpp index d49094f9e2..2adbceaec7 100644 --- a/source/module_hsolver/test/test_hsolver.cpp +++ b/source/module_hsolver/test/test_hsolver.cpp @@ -76,14 +76,14 @@ class TestHSolver : public ::testing::Test // EXPECT_EQ(hs_d.method, "none"); // } -TEST_F(TestHSolver, diagethr) -{ - float test_diagethr = hs_f.set_diagethr(0.0, 0, 0, 0.0); - EXPECT_EQ(test_diagethr, 0.0); +// TEST_F(TestHSolver, diagethr) +// { +// float test_diagethr = hs_f.set_diagethr(0.0, 0, 0, 0.0); +// EXPECT_EQ(test_diagethr, 0.0); - double test_diagethr_d = hs_d.set_diagethr(0.0, 0, 0, 0.0); - EXPECT_EQ(test_diagethr_d, 0.0); -} +// double test_diagethr_d = hs_d.set_diagethr(0.0, 0, 0, 0.0); +// EXPECT_EQ(test_diagethr_d, 0.0); +// } namespace hsolver { template