Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: enable cal_force and cal_stress in nscf #5752

Merged
merged 3 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion source/module_elecstate/elecstate_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
}

template <typename T, typename Device>
void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
{
const T one{1, 0};
const T zero{0, 0};
Expand Down Expand Up @@ -392,6 +392,12 @@ void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
}
}
delmem_complex_op()(this->ctx, becp);
}

template <typename T, typename Device>
void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
{
this->cal_becsum(psi);

// transform soft charge to recip space using smooth grids
T* rhog = nullptr;
Expand Down
5 changes: 4 additions & 1 deletion source/module_elecstate/elecstate_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class ElecStatePW : public ElecState

virtual void cal_tau(const psi::Psi<T, Device>& psi);

//! calculate becsum for uspp
void cal_becsum(const psi::Psi<T, Device>& psi);

Real* becsum = nullptr;

//! init rho_data and kin_r_data
Expand All @@ -61,7 +64,7 @@ class ElecStatePW : public ElecState

//! calcualte rho for each k
void rhoBandK(const psi::Psi<T, Device>& psi);

//! add to the charge density in reciprocal space the part which is due to the US augmentation.
void add_usrho(const psi::Psi<T, Device>& psi);

Expand Down
4 changes: 4 additions & 0 deletions source/module_hsolver/hsolver_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->calEBand();
if (skip_charge)
{
if (PARAM.globalv.use_uspp)
{
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->cal_becsum(psi);
}
ModuleBase::timer::tick("HSolverLIP", "solve");
return;
}
Expand Down
4 changes: 4 additions & 0 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->calEBand();
if (skip_charge)
{
if (PARAM.globalv.use_uspp)
{
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->cal_becsum(psi);
}
ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
Expand Down
42 changes: 41 additions & 1 deletion source/module_hsolver/test/hsolver_supplementary_mock.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once
#include "module_elecstate/elecstate.h"
#include "module_elecstate/elecstate_pw.h"
#include "module_psi/wavefunc.h"

namespace elecstate
Expand Down Expand Up @@ -62,6 +62,46 @@ void ElecState::init_ks(Charge* chg_in, // pointer for class Charge
return;
}

template <typename T, typename Device>
ElecStatePW<T, Device>::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in,
Charge* chg_in,
K_Vectors* pkv_in,
UnitCell* ucell_in,
pseudopot_cell_vnl* ppcell_in,
ModulePW::PW_Basis* rhodpw_in,
ModulePW::PW_Basis* rhopw_in,
ModulePW::PW_Basis_Big* bigpw_in)
: basis(wfc_basis_in)
{
}

template <typename T, typename Device>
ElecStatePW<T, Device>::~ElecStatePW()
{
}

template <typename T, typename Device>
void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
{
}

template <typename T, typename Device>
void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
{
}

template <typename T, typename Device>
void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
{
}

template class ElecStatePW<std::complex<float>, base_device::DEVICE_CPU>;
template class ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>;
#if ((defined __CUDA) || (defined __ROCM))
template class ElecStatePW<std::complex<float>, base_device::DEVICE_GPU>;
template class ElecStatePW<std::complex<double>, base_device::DEVICE_GPU>;
#endif

Potential::~Potential()
{
}
Expand Down
29 changes: 0 additions & 29 deletions source/module_hsolver/test/test_hsolver_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,11 @@ Sto_Func<REAL>::Sto_Func()
}
template class Sto_Func<double>;


template <>
elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in,
Charge* chg_in,
K_Vectors* pkv_in,
UnitCell* ucell_in,
pseudopot_cell_vnl* ppcell_in,
ModulePW::PW_Basis* rhodpw_in,
ModulePW::PW_Basis* rhopw_in,
ModulePW::PW_Basis_Big* bigpw_in)
: basis(wfc_basis_in)
{
}

template<>
elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::~ElecStatePW()
{
}

template<>
void elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::init_rho_data()
{
}

template<>
void elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::psiToRho(const psi::Psi<std::complex<double>, base_device::DEVICE_CPU>& psi)
{
}

template<>
void elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::cal_tau(const psi::Psi<std::complex<double>, base_device::DEVICE_CPU>& psi)
{
}

template <typename REAL, typename Device>
StoChe<REAL, Device>::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/read_input_item_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void ReadInput::item_system()
item.annotation = "if calculate the force at the end of the electronic iteration";
item.reset_value = [](const Input_Item& item, Parameter& para) {
std::vector<std::string> use_force = {"cell-relax", "relax", "md"};
std::vector<std::string> not_use_force = {"get_wf", "get_pchg", "nscf", "get_S"};
std::vector<std::string> not_use_force = {"get_wf", "get_pchg", "get_S"};
if (std::find(use_force.begin(), use_force.end(), para.input.calculation) != use_force.end())
{
if (!para.input.cal_force)
Expand Down
128 changes: 62 additions & 66 deletions source/module_relax/relax_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,90 +54,86 @@ void Relax_Driver::relax_driver(ModuleESolver::ESolver* p_esolver, UnitCell& uce
time_t fstart = time(nullptr);
ModuleBase::matrix force;
ModuleBase::matrix stress;
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax")
{
// I'm considering putting force and stress
// as part of ucell and use ucell to pass information
// back and forth between esolver and relaxation
// but I'll use force and stress explicitly here for now

// calculate the total energy
this->etot = p_esolver->cal_energy();
// I'm considering putting force and stress
// as part of ucell and use ucell to pass information
// back and forth between esolver and relaxation
// but I'll use force and stress explicitly here for now

// calculate the total energy
this->etot = p_esolver->cal_energy();

// calculate and gather all parts of total ionic forces
if (PARAM.inp.cal_force)
{
p_esolver->cal_force(ucell, force);
}
// calculate and gather all parts of stress
if (PARAM.inp.cal_stress)
{
p_esolver->cal_stress(ucell, stress);
}

// calculate and gather all parts of total ionic forces
if (PARAM.inp.cal_force)
if (PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax")
{
if (PARAM.inp.relax_new)
{
p_esolver->cal_force(ucell, force);
stop = rl.relax_step(ucell, force, stress, this->etot);
}
// calculate and gather all parts of stress
if (PARAM.inp.cal_stress)
else
{
p_esolver->cal_stress(ucell, stress);
stop = rl_old.relax_step(istep,
this->etot,
ucell,
force,
stress,
force_step,
stress_step); // pengfei Li 2018-05-14
}

if (PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax")
// print structure
// changelog 20240509
// because I move out the dependence on GlobalV from UnitCell::print_stru_file
// so its parameter is calculated here
bool need_orb = PARAM.inp.basis_type == "pw";
need_orb = need_orb && PARAM.inp.psi_initializer;
need_orb = need_orb && PARAM.inp.init_wfc.substr(0, 3) == "nao";
need_orb = need_orb || PARAM.inp.basis_type == "lcao";
need_orb = need_orb || PARAM.inp.basis_type == "lcao_in_pw";
std::stringstream ss, ss1;
ss << PARAM.globalv.global_out_dir << "STRU_ION_D";
ucell.print_stru_file(ss.str(),
PARAM.inp.nspin,
true,
PARAM.inp.calculation == "md",
PARAM.inp.out_mul,
need_orb,
PARAM.globalv.deepks_setorb,
GlobalV::MY_RANK);

if (Ions_Move_Basic::out_stru)
{
if (PARAM.inp.relax_new)
{
stop = rl.relax_step(ucell,force, stress, this->etot);
}
else
{
stop = rl_old.relax_step(istep,
this->etot,
ucell,
force,
stress,
force_step,
stress_step); // pengfei Li 2018-05-14
}
// print structure
// changelog 20240509
// because I move out the dependence on GlobalV from UnitCell::print_stru_file
// so its parameter is calculated here
bool need_orb = PARAM.inp.basis_type == "pw";
need_orb = need_orb && PARAM.inp.psi_initializer;
need_orb = need_orb && PARAM.inp.init_wfc.substr(0, 3) == "nao";
need_orb = need_orb || PARAM.inp.basis_type == "lcao";
need_orb = need_orb || PARAM.inp.basis_type == "lcao_in_pw";
std::stringstream ss, ss1;
ss << PARAM.globalv.global_out_dir << "STRU_ION_D";
ucell.print_stru_file(ss.str(),
ss1 << PARAM.globalv.global_out_dir << "STRU_ION";
ss1 << istep << "_D";
ucell.print_stru_file(ss1.str(),
PARAM.inp.nspin,
true,
PARAM.inp.calculation == "md",
PARAM.inp.out_mul,
need_orb,
PARAM.globalv.deepks_setorb,
GlobalV::MY_RANK);

if (Ions_Move_Basic::out_stru)
{
ss1 << PARAM.globalv.global_out_dir << "STRU_ION";
ss1 << istep << "_D";
ucell.print_stru_file(ss1.str(),
PARAM.inp.nspin,
true,
PARAM.inp.calculation == "md",
PARAM.inp.out_mul,
need_orb,
PARAM.globalv.deepks_setorb,
GlobalV::MY_RANK);
ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU_NOW.cif",
ucell,
"# Generated by ABACUS ModuleIO::CifParser",
"data_?");
}

ModuleIO::output_after_relax(stop, p_esolver->conv_esolver, GlobalV::ofs_running);
ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU_NOW.cif",
ucell,
"# Generated by ABACUS ModuleIO::CifParser",
"data_?");
}

#ifdef __RAPIDJSON
// add the energy to outout
Json::add_output_energy(p_esolver->cal_energy() * ModuleBase::Ry_to_eV);
#endif
ModuleIO::output_after_relax(stop, p_esolver->conv_esolver, GlobalV::ofs_running);
}

#ifdef __RAPIDJSON
// add the energy to outout
Json::add_output_energy(p_esolver->cal_energy() * ModuleBase::Ry_to_eV);
// add Json of cell coo stress force
double unit_transform = ModuleBase::RYDBERG_SI / pow(ModuleBase::BOHR_RADIUS_SI, 3) * 1.0e-8;
double fac = ModuleBase::Ry_to_eV / 0.529177;
Expand Down
Loading