diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 5ac41aab9a..13f41bf455 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -27,7 +27,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1); + psi::Psi wg_wfc(wfc, 1, nbands_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) @@ -41,7 +41,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) continue; + if (ib_global >= wg.nc) { continue; +} const double wg_local = wg(ik, ib_global); double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); @@ -99,7 +100,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) continue; + if (ib_global >= wg.nc) { continue; +} const double wg_local = wg(ik, ib_global); std::complex* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); diff --git a/source/module_elecstate/module_dm/cal_dm_psi.cpp b/source/module_elecstate/module_dm/cal_dm_psi.cpp index 47fbfbf8c3..cd868dcf9e 100644 --- a/source/module_elecstate/module_dm/cal_dm_psi.cpp +++ b/source/module_elecstate/module_dm/cal_dm_psi.cpp @@ -32,7 +32,8 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); // dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1); + + psi::Psi wg_wfc(wfc, 1, nbands_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a99e813e01..4a830489f4 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -4,28 +4,31 @@ using namespace hamilt; - -template -Operator::Operator(){} - -template -Operator::~Operator() +template +Operator::Operator() { - if(this->hpsi != nullptr) { delete this->hpsi; } + +template +Operator::~Operator() +{ + if (this->hpsi != nullptr) + { + delete this->hpsi; + } Operator* last = this->next_op; Operator* last_sub = this->next_sub_op; - while(last != nullptr || last_sub != nullptr) + while (last != nullptr || last_sub != nullptr) { - if(last_sub != nullptr) - {//delete sub_chain first + if (last_sub != nullptr) + { // delete sub_chain first Operator* node_delete = last_sub; last_sub = last_sub->next_sub_op; node_delete->next_sub_op = nullptr; delete node_delete; } else - {//delete main chain if sub_chain is deleted + { // delete main chain if sub_chain is deleted Operator* node_delete = last; last_sub = last->next_sub_op; node_delete->next_sub_op = nullptr; @@ -36,7 +39,7 @@ Operator::~Operator() } } -template +template typename Operator::hpsi_info Operator::hPsi(hpsi_info& input) const { using syncmem_op = base_device::memory::synchronize_memory_op; @@ -46,12 +49,12 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp T* tmhpsi = this->get_hpsi(input); const T* tmpsi_in = std::get<0>(psi_info); - //if range in hpsi_info is illegal, the first return of to_range() would be nullptr + // if range in hpsi_info is illegal, the first return of to_range() would be nullptr if (tmpsi_in == nullptr) { ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!"); } - //if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return + // if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return T* hpsi_pointer = std::get<2>(input); if (this->in_place) { @@ -62,21 +65,26 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp } auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void { - // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - - + switch (op->get_act_type()) { case 2: op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node); + op->act(nbands, + psi_input->get_nbasis(), + psi_input->npol, + tmpsi_in, + this->hpsi->get_pointer(), + psi_input->get_ngk(op->ik), + // psi_input->get_current_nbas(), + is_first_node); break; } - }; + }; ModuleBase::timer::tick("Operator", "hPsi"); call_act(this, true); // first node @@ -91,39 +99,43 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); } - -template -void Operator::init(const int ik_in) +template +void Operator::init(const int ik_in) { this->ik = ik_in; - if(this->next_op != nullptr) { + if (this->next_op != nullptr) + { this->next_op->init(ik_in); } } -template -void Operator::add(Operator* next) +template +void Operator::add(Operator* next) { - if(next==nullptr) { return; -} + if (next == nullptr) + { + return; + } next->is_first_node = false; - if(next->next_op != nullptr) { this->add(next->next_op); -} + if (next->next_op != nullptr) + { + this->add(next->next_op); + } Operator* last = this; - //loop to end of the chain - while(last->next_op != nullptr) + // loop to end of the chain + while (last->next_op != nullptr) { - if(next->cal_type==last->cal_type) + if (next->cal_type == last->cal_type) { break; } last = last->next_op; } - if(next->cal_type == last->cal_type) + if (next->cal_type == last->cal_type) { - //insert next to sub chain of current node + // insert next to sub chain of current node Operator* sub_last = last; - while(sub_last->next_sub_op != nullptr) + while (sub_last->next_sub_op != nullptr) { sub_last = sub_last->next_sub_op; } @@ -136,24 +148,24 @@ void Operator::add(Operator* next) } } -template +template T* Operator::get_hpsi(const hpsi_info& info) const { const int nbands_range = (std::get<1>(info).range_2 - std::get<1>(info).range_1 + 1); - //in_place call of hPsi, hpsi inputs as new psi, - //create a new hpsi and delete old hpsi later + // in_place call of hPsi, hpsi inputs as new psi, + // create a new hpsi and delete old hpsi later T* hpsi_pointer = std::get<2>(info); const T* psi_pointer = std::get<0>(info)->get_pointer(); - if(this->hpsi != nullptr) + if (this->hpsi != nullptr) { delete this->hpsi; this->hpsi = nullptr; } - if(!hpsi_pointer) + if (!hpsi_pointer) { ModuleBase::WARNING_QUIT("Operator::hPsi", "hpsi_pointer can not be nullptr"); } - else if(hpsi_pointer == psi_pointer) + else if (hpsi_pointer == psi_pointer) { this->in_place = true; this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); @@ -163,7 +175,7 @@ T* Operator::get_hpsi(const hpsi_info& info) const this->in_place = false; this->hpsi = new psi::Psi(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range); } - + hpsi_pointer = this->hpsi->get_pointer(); size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis(); // ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size); @@ -172,7 +184,8 @@ T* Operator::get_hpsi(const hpsi_info& info) const return hpsi_pointer; } -namespace hamilt { +namespace hamilt +{ template class Operator; template class Operator, base_device::DEVICE_CPU>; template class Operator; @@ -183,4 +196,4 @@ template class Operator, base_device::DEVICE_GPU>; template class Operator; template class Operator, base_device::DEVICE_GPU>; #endif -} +} // namespace hamilt diff --git a/source/module_hamilt_general/operator.h b/source/module_hamilt_general/operator.h index 6cf29122fe..80ed065ccc 100644 --- a/source/module_hamilt_general/operator.h +++ b/source/module_hamilt_general/operator.h @@ -1,19 +1,19 @@ #ifndef OPERATOR_H #define OPERATOR_H -#include - #include "module_base/global_function.h" #include "module_base/tool_quit.h" #include "module_psi/psi.h" +#include + namespace hamilt { enum class calculation_type { no, - pw_ekinetic, + pw_ekinetic, pw_nonlocal, pw_veff, pw_meta, @@ -28,49 +28,54 @@ enum class calculation_type lcao_tddft_velocity, }; -// Basic class for operator module, +// Basic class for operator module, // it is designed for "O|psi>" and "" // Operator "O" might have several different types, which should be calculated one by one. // In basic class , function add() is designed for combine all operators together with a chain. template class Operator { - public: + public: Operator(); virtual ~Operator(); - //this is the core function for Operator - // do H|psi> from input |psi> , + // this is the core function for Operator + // do H|psi> from input |psi> , /// as default, different operators donate hPsi independently - /// run this->act function for the first operator and run all act() for other nodes in chain table + /// run this->act function for the first operator and run all act() for other nodes in chain table /// if this procedure is not suitable for your operator, just override this function. - /// output of hpsi would be first member of the returned tuple + /// output of hpsi would be first member of the returned tuple typedef std::tuple*, const psi::Range, T*> hpsi_info; - virtual hpsi_info hPsi(hpsi_info& input)const; + + virtual hpsi_info hPsi(hpsi_info& input) const; virtual void init(const int ik_in); virtual void add(Operator* next); - virtual int get_ik() const { return this->ik; } + virtual int get_ik() const + { + return this->ik; + } - ///do operation : |hpsi_choosed> = V|psi_choosed> - ///V is the target operator act on choosed psi, the consequence should be added to choosed hpsi - /// interface type 1: pointer-only (default) - /// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!! + /// do operation : |hpsi_choosed> = V|psi_choosed> + /// V is the target operator act on choosed psi, the consequence should be added to choosed hpsi + /// interface type 1: pointer-only (default) + /// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!! virtual void act(const int nbands, - const int nbasis, - const int npol, - const T* tmpsi_in, - T* tmhpsi, - const int ngk_ik = 0, - const bool is_first_node = false)const {}; + const int nbasis, + const int npol, + const T* tmpsi_in, + T* tmhpsi, + const int ngk_ik = 0, + const bool is_first_node = false) const {}; /// developer-friendly interfaces for act() function /// interface type 2: input and change the Psi-type HPsi // virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out) const {}; virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out, const int nbands) const {}; + /// interface type 3: return a Psi-type HPsi // virtual psi::Psi act(const psi::Psi& psi_in) const { return psi_in; }; @@ -78,36 +83,41 @@ class Operator /// type 1 (default): pointer-only /// act(const T* psi_in, T* psi_out) - /// type 2: use the `Psi`class + /// type 2: use the `Psi`class /// act(const Psi& psi_in, Psi& psi_out) - int get_act_type() const { return this->act_type; } -protected: + int get_act_type() const + { + return this->act_type; + } + + protected: int ik = 0; - int act_type = 1; ///< determine which act() interface would be called in hPsi() + int act_type = 1; ///< determine which act() interface would be called in hPsi() mutable bool in_place = false; - //calculation type, only different type can be in main chain table + // calculation type, only different type can be in main chain table enum calculation_type cal_type; Operator* next_sub_op = nullptr; bool is_first_node = true; - //if this Operator is first node in chain table, hpsi would not be empty + // if this Operator is first node in chain table, hpsi would not be empty mutable psi::Psi* hpsi = nullptr; /*This function would analyze hpsi_info and choose how to arrange hpsi storage In hpsi_info, if the third parameter hpsi_pointer is set, which indicates memory of hpsi is arranged by developer; - if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare. + if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare. two cases would occurred: - 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary memory - 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case + 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary + memory + 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case */ - T* get_hpsi(const hpsi_info& info)const; + T* get_hpsi(const hpsi_info& info) const; - Device *ctx = {}; + Device* ctx = {}; using set_memory_op = base_device::memory::set_memory_op; }; -}//end namespace hamilt +} // end namespace hamilt #endif \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 858e6b3fd5..0d49cadaa0 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -23,7 +23,8 @@ Velocity::Velocity ModuleBase::WARNING_QUIT("Velocity", "Constuctor of Operator::Velocity is failed, please check your code!"); } this->tpiba = ucell_in -> tpiba; - if(this->nonlocal) this->ppcell->initgradq_vnl(*this->ucell); + if(this->nonlocal) { this->ppcell->initgradq_vnl(*this->ucell); +} } void Velocity::init(const int ik_in) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 0c1ad2e8b8..a885296c62 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -378,10 +378,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ngk_vector[i] = ngk_pointer[i]; } + const int cur_nbasis = psi.get_ngk(psi.get_current_k()); + if (this->method == "cg") { // wrap the subspace_func into a lambda function - auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) { + auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) { // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); @@ -391,12 +393,14 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, psi_in.shape().dim_size(0), psi_in.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); auto psi_out_wrapper = psi::Psi(psi_out.data(), 1, psi_out.shape().dim_size(0), psi_out.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); auto eigen = ct::Tensor(ct::DataTypeToEnum::value, ct::DeviceType::CpuDevice, ct::TensorShape({psi_in.shape().dim_size(0)})); @@ -415,7 +419,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, using ct_Device = typename ct::PsiToContainer::type; // wrap the hpsi_func and spsi_func into a lambda function - auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] @@ -426,7 +430,8 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, ndim == 1 ? 1 : psi_in.shape().dim_size(0), ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); @@ -486,11 +491,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband = psi.get_nbands(); const int nbasis = psi.get_nbasis(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -507,11 +512,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, else if (this->method == "dav_subspace") { // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -558,11 +563,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, // Davidson matrix-blockvector functions /// wrap hpsi into lambda function, Matrix \times blockvector // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("David", "hpsi_func"); // Convert pointer of psi_in to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 6e069fd017..4b3013b581 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -478,7 +478,8 @@ void IState_Charge::idmatrix(const int& ib, // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); this->psi_gamma->fix_k(is); - psi::Psi wg_wfc(*this->psi_gamma, 1); + + psi::Psi wg_wfc(*this->psi_gamma, 1, this->psi_gamma->get_nbands()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { @@ -540,7 +541,8 @@ void IState_Charge::idmatrix(const int& ib, } this->psi_k->fix_k(ik); - psi::Psi> wg_wfc(*this->psi_k, 1); + // psi::Psi> wg_wfc(*this->psi_k, 1); + psi::Psi> wg_wfc(1, this->psi_k->get_nbands(), this->psi_k->get_nbasis()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { diff --git a/source/module_io/write_dos_lcao.cpp b/source/module_io/write_dos_lcao.cpp index e475c77459..df07cef1d6 100644 --- a/source/module_io/write_dos_lcao.cpp +++ b/source/module_io/write_dos_lcao.cpp @@ -461,7 +461,8 @@ void ModuleIO::write_dos_lcao(const UnitCell& ucell, } psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1); + psi::Psi> Dwfc(*psi, 1, psi->get_nbands()); + std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index ccd7a0d4b0..b5660f7da5 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -25,8 +25,9 @@ void ModuleIO::write_proj_band_lcao( const double* sk = dynamic_cast*>(p_ham)->getSk(); int nspin0 = 1; - if (PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 2) { nspin0 = 2; +} int nks = 0; if (nspin0 == 1) { @@ -103,14 +104,16 @@ void ModuleIO::write_proj_band_lcao( out << "" << std::endl; out << "" << PARAM.inp.nspin << "" << std::endl; - if (PARAM.inp.nspin == 4) + if (PARAM.inp.nspin == 4) { out << "" << std::setw(2) << PARAM.globalv.nlocal / 2 << "" << std::endl; - else + } else { out << "" << std::setw(2) << PARAM.globalv.nlocal << "" << std::endl; +} out << "" << std::endl; - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.inp.nbands; ib++) { out << " " << (pelec->ekb(is * nks, ib)) * ModuleBase::Ry_to_eV; +} out << std::endl; out << "" << std::endl; @@ -139,9 +142,9 @@ void ModuleIO::write_proj_band_lcao( out << "" << std::endl; for (int ib = 0; ib < PARAM.inp.nbands; ib++) { - if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) { out << std::setw(13) << weight(is, ib * PARAM.globalv.nlocal + w); - else if (PARAM.inp.nspin == 4) + } else if (PARAM.inp.nspin == 4) { int w0 = w - s0; out << std::setw(13) @@ -178,8 +181,9 @@ void ModuleIO::write_proj_band_lcao( ModuleBase::timer::tick("ModuleIO", "write_proj_band_lcao"); int nspin0 = 1; - if (PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 2) { nspin0 = 2; +} int nks = 0; if (nspin0 == 1) { @@ -221,7 +225,8 @@ void ModuleIO::write_proj_band_lcao( // calculate Mulk psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1); + psi::Psi> Dwfc(psi[0], 1, psi->get_nbands()); + std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { @@ -301,8 +306,9 @@ void ModuleIO::write_proj_band_lcao( for (int ik = 0; ik < nks; ik++) { - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.inp.nbands; ib++) { out << " " << (pelec->ekb(ik + is * nks, ib)) * ModuleBase::Ry_to_eV; +} out << std::endl; } out << "" << std::endl; diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index fb8abc78cd..3fcb347790 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -32,7 +32,6 @@ template Psi::Psi() { this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); } template @@ -52,8 +51,9 @@ Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); + this->resize(nk_in, nbd_in, nbs_in); + // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); base_device::information::record_device_memory(this->ctx, @@ -69,6 +69,7 @@ Psi::Psi(T* psi_pointer, const int nbd_in, const int nbs_in, const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in) { this->k_first = k_first_in; @@ -76,11 +77,10 @@ Psi::Psi(T* psi_pointer, this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; - this->current_nbasis = nbs_in; + this->current_nbasis = current_nbasis_in; this->psi_current = this->psi = psi_pointer; this->allocate_inside = false; // Currently only GPU's implementation is supported for device recording! @@ -96,7 +96,6 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; @@ -108,28 +107,40 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int } template -Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) +Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) { - assert(nk_in <= psi_in.get_nk()); - if (nband_in == 0) - { - nband_in = psi_in.get_nbands(); - } + assert(nk_in <= psi_in.get_nk() && nk_in > 0); + assert(nband_in <= psi_in.get_nbands() && nband_in > 0); + this->k_first = psi_in.get_k_first(); - this->device = psi_in.device; - this->resize(nk_in, nband_in, psi_in.get_nbasis()); - this->ngk = psi_in.ngk; this->npol = psi_in.npol; - if (nband_in <= psi_in.get_nbands()) + this->allocate_inside = true; + + this->nk = nk_in; + this->nbands = nband_in; + this->nbasis = psi_in.get_nbasis(); + + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. + resize_memory_op()(this->ctx, + this->psi, + (static_cast(this->nk) * static_cast(this->nbands) + * static_cast(this->nbasis)), + "no_record"); + synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); + + this->current_k = 0; + this->current_b = 0; + this->current_nbasis = this->nbasis; + this->psi_current = this->psi; + this->psi_bias = 0; + + if (this->nk != psi_in.get_nk()) + { + this->ngk = nullptr; + } + else { - // copy from Psi from psi_in(current_k, 0, 0), - // if size of k is 1, current_k in new Psi is psi_in.current_k - if (nk_in == 1) - { - // current_k for this Psi only keep the spin index same as the copied Psi - this->current_k = psi_in.get_current_k(); - } - synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); + this->ngk = psi_in.ngk; } } @@ -137,8 +148,6 @@ template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { this->k_first = psi_in.get_k_first(); - this->device = base_device::get_device_type(this->ctx); - assert(this->device == psi_in.device); assert(nk_in <= psi_in.get_nk()); if (nband_in == 0) { @@ -166,7 +175,7 @@ Psi::Psi(const Psi& psi_in) this->current_b = psi_in.get_current_b(); this->k_first = psi_in.get_k_first(); // this function will copy psi_in.psi to this->psi no matter the device types of each other. - this->device = base_device::get_device_type(this->ctx); + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); base_device::memory::synchronize_memory_op()(this->ctx, psi_in.get_device(), @@ -191,7 +200,7 @@ Psi::Psi(const Psi& psi_in) this->current_b = psi_in.get_current_b(); this->k_first = psi_in.get_k_first(); // this function will copy psi_in.psi to this->psi no matter the device types of each other. - this->device = base_device::get_device_type(this->ctx); + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); // Specifically, if the Device_in type is CPU and the Device type is GPU. @@ -234,8 +243,12 @@ template void Psi::resize(const int nks_in, const int nbands_in, const int nbasis_in) { assert(nks_in > 0 && nbands_in >= 0 && nbasis_in > 0); + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. resize_memory_op()(this->ctx, this->psi, nks_in * static_cast(nbands_in) * nbasis_in, "no_record"); + + // this->zero_out(); + this->nk = nks_in; this->nbands = nbands_in; this->nbasis = nbasis_in; diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 6b374c8a70..860112f066 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -43,7 +43,7 @@ class Psi Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - Psi(const Psi& psi_in, const int nk_in, int nband_in = 0); + Psi(const Psi& psi_in, const int nk_in, const int nband_in); // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() // in this case, fix_k can not be used @@ -64,17 +64,13 @@ class Psi const int nbd_in, const int nbs_in, const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in = true); // Constructor 8-2: a pointer version of constructor 3 // only used in operator.cpp call_act func - Psi(T* psi_pointer, - const int nk_in, - const int nbd_in, - const int nbs_in, - const bool k_first_in); + Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in); - // Destructor for deleting the psi array manually ~Psi(); @@ -136,12 +132,12 @@ class Psi // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; + int npol = 1; private: T* psi = nullptr; // avoid using C++ STL - - base_device::AbacusDevice_t device = {}; // track the device type (CPU, GPU and SYCL are supported currented) + Device* ctx = {}; // an context identifier for obtaining the device variable // dimensions diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index df22b5f885..fa3f357407 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -15,7 +15,7 @@ class TestPsi : public ::testing::Test const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); - psi::Psi>* psi_object4 = new psi::Psi>(*psi_object31, ink, 0); + // psi::Psi>* psi_object4 = new psi::Psi>(*psi_object31, ink, 0); psi::Psi>* psi_object5 = new psi::Psi>(psi_object31->get_pointer(), *psi_object31, ink, 0); };