From 84e63b6ca4318a787fe3330393829493c1aad0ee Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 02:46:32 +0000 Subject: [PATCH 01/16] remove Psi(const Psi& psi_in, const int nk_in, int nband_in); --- source/module_elecstate/cal_dm.h | 3 ++- .../module_elecstate/module_dm/cal_dm_psi.cpp | 3 ++- source/module_hamilt_general/operator.cpp | 3 ++- source/module_io/get_pchg_lcao.cpp | 6 +++-- source/module_io/write_dos_lcao.cpp | 4 ++- source/module_io/write_proj_band_lcao.cpp | 4 ++- source/module_psi/psi.cpp | 26 ------------------- source/module_psi/psi.h | 3 --- source/module_psi/test/psi_test.cpp | 2 +- 9 files changed, 17 insertions(+), 37 deletions(-) diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 5ac41aab9a..5344cabc1a 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -27,7 +27,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1); + // psi::Psi wg_wfc(wfc, 1); + psi::Psi wg_wfc(1, nbands_local, nbasis_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_elecstate/module_dm/cal_dm_psi.cpp b/source/module_elecstate/module_dm/cal_dm_psi.cpp index 47fbfbf8c3..dc15f0635c 100644 --- a/source/module_elecstate/module_dm/cal_dm_psi.cpp +++ b/source/module_elecstate/module_dm/cal_dm_psi.cpp @@ -32,7 +32,8 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); // dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1); + // psi::Psi wg_wfc(wfc, 1, ); + psi::Psi wg_wfc(1, nbands_local, nbasis_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a99e813e01..3ec209d0bf 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -156,7 +156,8 @@ T* Operator::get_hpsi(const hpsi_info& info) const else if(hpsi_pointer == psi_pointer) { this->in_place = true; - this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); + // this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); + this->hpsi = new psi::Psi(1, nbands_range, std::get<0>(info)->get_nbasis()); } else { diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 6e069fd017..4cd3b05024 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -478,7 +478,8 @@ void IState_Charge::idmatrix(const int& ib, // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); this->psi_gamma->fix_k(is); - psi::Psi wg_wfc(*this->psi_gamma, 1); + // psi::Psi wg_wfc(*this->psi_gamma, 1); + psi::Psi wg_wfc(1, this->psi_gamma->get_nbands(), this->psi_gamma->get_nbasis()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { @@ -540,7 +541,8 @@ void IState_Charge::idmatrix(const int& ib, } this->psi_k->fix_k(ik); - psi::Psi> wg_wfc(*this->psi_k, 1); + // psi::Psi> wg_wfc(*this->psi_k, 1); + psi::Psi> wg_wfc(1, this->psi_k->get_nbands(), this->psi_k->get_nbasis()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { diff --git a/source/module_io/write_dos_lcao.cpp b/source/module_io/write_dos_lcao.cpp index e475c77459..8f4c251f4c 100644 --- a/source/module_io/write_dos_lcao.cpp +++ b/source/module_io/write_dos_lcao.cpp @@ -461,7 +461,9 @@ void ModuleIO::write_dos_lcao(const UnitCell& ucell, } psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1); + // psi::Psi> Dwfc(psi[0], 1); + psi::Psi> Dwfc(1, psi->get_nbands(), psi->get_nbasis()); + std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index ccd7a0d4b0..18759f2e29 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -221,7 +221,9 @@ void ModuleIO::write_proj_band_lcao( // calculate Mulk psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1); + // psi::Psi> Dwfc(psi[0], 1); + psi::Psi> Dwfc(1, psi->get_nbands(), psi->get_nbasis()); + std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index fb8abc78cd..20940529f1 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -107,32 +107,6 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } -template -Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) -{ - assert(nk_in <= psi_in.get_nk()); - if (nband_in == 0) - { - nband_in = psi_in.get_nbands(); - } - this->k_first = psi_in.get_k_first(); - this->device = psi_in.device; - this->resize(nk_in, nband_in, psi_in.get_nbasis()); - this->ngk = psi_in.ngk; - this->npol = psi_in.npol; - if (nband_in <= psi_in.get_nbands()) - { - // copy from Psi from psi_in(current_k, 0, 0), - // if size of k is 1, current_k in new Psi is psi_in.current_k - if (nk_in == 1) - { - // current_k for this Psi only keep the spin index same as the copied Psi - this->current_k = psi_in.get_current_k(); - } - synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); - } -} - template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 6b374c8a70..2fe4f6cca6 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -42,9 +42,6 @@ class Psi // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); - // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - Psi(const Psi& psi_in, const int nk_in, int nband_in = 0); - // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() // in this case, fix_k can not be used Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0); diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index df22b5f885..fa3f357407 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -15,7 +15,7 @@ class TestPsi : public ::testing::Test const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); - psi::Psi>* psi_object4 = new psi::Psi>(*psi_object31, ink, 0); + // psi::Psi>* psi_object4 = new psi::Psi>(*psi_object31, ink, 0); psi::Psi>* psi_object5 = new psi::Psi>(psi_object31->get_pointer(), *psi_object31, ink, 0); }; From b15cd5c2413d02db421c336a8f4aad6c6949c04e Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 03:30:14 +0000 Subject: [PATCH 02/16] fix bug --- source/module_psi/psi.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 20940529f1..28b6f4c90b 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -208,8 +208,12 @@ template void Psi::resize(const int nks_in, const int nbands_in, const int nbasis_in) { assert(nks_in > 0 && nbands_in >= 0 && nbasis_in > 0); + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. resize_memory_op()(this->ctx, this->psi, nks_in * static_cast(nbands_in) * nbasis_in, "no_record"); + + this->zero_out(); + this->nk = nks_in; this->nbands = nbands_in; this->nbasis = nbasis_in; From 9900bb77542ecb8bcd305a4f36e5243d50cc04b7 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 07:20:46 +0000 Subject: [PATCH 03/16] fix bug --- source/module_elecstate/cal_dm.h | 3 +- .../module_elecstate/module_dm/cal_dm_psi.cpp | 4 +-- source/module_hamilt_general/operator.cpp | 3 +- source/module_io/get_pchg_lcao.cpp | 4 +-- source/module_io/write_dos_lcao.cpp | 3 +- source/module_io/write_proj_band_lcao.cpp | 3 +- source/module_psi/psi.cpp | 32 +++++++++++++++++-- source/module_psi/psi.h | 4 +++ 8 files changed, 42 insertions(+), 14 deletions(-) diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 5344cabc1a..b28f685d20 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -27,8 +27,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - // psi::Psi wg_wfc(wfc, 1); - psi::Psi wg_wfc(1, nbands_local, nbasis_local); + psi::Psi wg_wfc(wfc, 1, nbands_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_elecstate/module_dm/cal_dm_psi.cpp b/source/module_elecstate/module_dm/cal_dm_psi.cpp index dc15f0635c..cd868dcf9e 100644 --- a/source/module_elecstate/module_dm/cal_dm_psi.cpp +++ b/source/module_elecstate/module_dm/cal_dm_psi.cpp @@ -32,8 +32,8 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); // dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - // psi::Psi wg_wfc(wfc, 1, ); - psi::Psi wg_wfc(1, nbands_local, nbasis_local); + + psi::Psi wg_wfc(wfc, 1, nbands_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 3ec209d0bf..a99e813e01 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -156,8 +156,7 @@ T* Operator::get_hpsi(const hpsi_info& info) const else if(hpsi_pointer == psi_pointer) { this->in_place = true; - // this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); - this->hpsi = new psi::Psi(1, nbands_range, std::get<0>(info)->get_nbasis()); + this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); } else { diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 4cd3b05024..4b3013b581 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -478,8 +478,8 @@ void IState_Charge::idmatrix(const int& ib, // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); this->psi_gamma->fix_k(is); - // psi::Psi wg_wfc(*this->psi_gamma, 1); - psi::Psi wg_wfc(1, this->psi_gamma->get_nbands(), this->psi_gamma->get_nbasis()); + + psi::Psi wg_wfc(*this->psi_gamma, 1, this->psi_gamma->get_nbands()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { diff --git a/source/module_io/write_dos_lcao.cpp b/source/module_io/write_dos_lcao.cpp index 8f4c251f4c..df07cef1d6 100644 --- a/source/module_io/write_dos_lcao.cpp +++ b/source/module_io/write_dos_lcao.cpp @@ -461,8 +461,7 @@ void ModuleIO::write_dos_lcao(const UnitCell& ucell, } psi->fix_k(ik); - // psi::Psi> Dwfc(psi[0], 1); - psi::Psi> Dwfc(1, psi->get_nbands(), psi->get_nbasis()); + psi::Psi> Dwfc(*psi, 1, psi->get_nbands()); std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index 18759f2e29..34843207b7 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -221,8 +221,7 @@ void ModuleIO::write_proj_band_lcao( // calculate Mulk psi->fix_k(ik); - // psi::Psi> Dwfc(psi[0], 1); - psi::Psi> Dwfc(1, psi->get_nbands(), psi->get_nbasis()); + psi::Psi> Dwfc(psi[0], 1, psi->get_nbands()); std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 28b6f4c90b..85e2f416cb 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -107,6 +107,34 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } + +template +Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) +{ + assert(nk_in <= psi_in.get_nk()); + if (nband_in == 0) + { + nband_in = psi_in.get_nbands(); + } + this->k_first = psi_in.get_k_first(); + this->device = psi_in.device; + this->resize(nk_in, nband_in, psi_in.get_nbasis()); + this->ngk = psi_in.ngk; + this->npol = psi_in.npol; + if (nband_in <= psi_in.get_nbands()) + { + // copy from Psi from psi_in(current_k, 0, 0), + // if size of k is 1, current_k in new Psi is psi_in.current_k + if (nk_in == 1) + { + // current_k for this Psi only keep the spin index same as the copied Psi + this->current_k = psi_in.get_current_k(); + } + synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); + } +} + + template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { @@ -208,11 +236,11 @@ template void Psi::resize(const int nks_in, const int nbands_in, const int nbasis_in) { assert(nks_in > 0 && nbands_in >= 0 && nbasis_in > 0); - + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. resize_memory_op()(this->ctx, this->psi, nks_in * static_cast(nbands_in) * nbasis_in, "no_record"); - this->zero_out(); + // this->zero_out(); this->nk = nks_in; this->nbands = nbands_in; diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 2fe4f6cca6..41ac645ce3 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -42,6 +42,10 @@ class Psi // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); + // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in + Psi(const Psi& psi_in, const int nk_in, int nband_in = 0); + + // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() // in this case, fix_k can not be used Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0); From a02b5d8d5eb139a9d5ff46e41441807c93db5722 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 25 Dec 2024 08:05:47 +0000 Subject: [PATCH 04/16] [pre-commit.ci lite] apply automatic fixes --- source/module_elecstate/cal_dm.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index b28f685d20..13f41bf455 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -41,7 +41,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) continue; + if (ib_global >= wg.nc) { continue; +} const double wg_local = wg(ik, ib_global); double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); @@ -99,7 +100,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) continue; + if (ib_global >= wg.nc) { continue; +} const double wg_local = wg(ik, ib_global); std::complex* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); From 8e3a58fd7c60f60524f8301d74abb3645fefce47 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 16:37:04 +0800 Subject: [PATCH 05/16] remove device value in psi --- source/module_psi/psi.cpp | 21 +++++++-------------- source/module_psi/psi.h | 15 ++++----------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 85e2f416cb..7bd2996808 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -32,7 +32,6 @@ template Psi::Psi() { this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); } template @@ -52,8 +51,9 @@ Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); + this->resize(nk_in, nbd_in, nbs_in); + // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); base_device::information::record_device_memory(this->ctx, @@ -76,7 +76,6 @@ Psi::Psi(T* psi_pointer, this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; @@ -96,7 +95,6 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; @@ -111,13 +109,10 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int template Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) { - assert(nk_in <= psi_in.get_nk()); - if (nband_in == 0) - { - nband_in = psi_in.get_nbands(); - } + assert(nk_in <= psi_in.get_nk() && nk_in > 0); + assert(nband_in <= psi_in.get_nbands() && nband_in > 0); + this->k_first = psi_in.get_k_first(); - this->device = psi_in.device; this->resize(nk_in, nband_in, psi_in.get_nbasis()); this->ngk = psi_in.ngk; this->npol = psi_in.npol; @@ -139,8 +134,6 @@ template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { this->k_first = psi_in.get_k_first(); - this->device = base_device::get_device_type(this->ctx); - assert(this->device == psi_in.device); assert(nk_in <= psi_in.get_nk()); if (nband_in == 0) { @@ -168,7 +161,7 @@ Psi::Psi(const Psi& psi_in) this->current_b = psi_in.get_current_b(); this->k_first = psi_in.get_k_first(); // this function will copy psi_in.psi to this->psi no matter the device types of each other. - this->device = base_device::get_device_type(this->ctx); + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); base_device::memory::synchronize_memory_op()(this->ctx, psi_in.get_device(), @@ -193,7 +186,7 @@ Psi::Psi(const Psi& psi_in) this->current_b = psi_in.get_current_b(); this->k_first = psi_in.get_k_first(); // this function will copy psi_in.psi to this->psi no matter the device types of each other. - this->device = base_device::get_device_type(this->ctx); + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); // Specifically, if the Device_in type is CPU and the Device type is GPU. diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 41ac645ce3..042fd865d7 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -43,9 +43,8 @@ class Psi Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - Psi(const Psi& psi_in, const int nk_in, int nband_in = 0); - - + Psi(const Psi& psi_in, const int nk_in, const int nband_in); + // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() // in this case, fix_k can not be used Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0); @@ -69,13 +68,8 @@ class Psi // Constructor 8-2: a pointer version of constructor 3 // only used in operator.cpp call_act func - Psi(T* psi_pointer, - const int nk_in, - const int nbd_in, - const int nbs_in, - const bool k_first_in); + Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in); - // Destructor for deleting the psi array manually ~Psi(); @@ -141,8 +135,7 @@ class Psi private: T* psi = nullptr; // avoid using C++ STL - - base_device::AbacusDevice_t device = {}; // track the device type (CPU, GPU and SYCL are supported currented) + Device* ctx = {}; // an context identifier for obtaining the device variable // dimensions From c716bb7e15ef8f62347c765b5495a81cbbb043cd Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 17:44:17 +0800 Subject: [PATCH 06/16] update Psi(const Psi& psi_in, const int nk_in, int nband_in) --- source/module_psi/psi.cpp | 39 ++++++++++++++++++++++++++------------- source/module_psi/psi.h | 1 + 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 7bd2996808..cf2a926c22 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -105,7 +105,6 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } - template Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) { @@ -113,23 +112,37 @@ Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) assert(nband_in <= psi_in.get_nbands() && nband_in > 0); this->k_first = psi_in.get_k_first(); - this->resize(nk_in, nband_in, psi_in.get_nbasis()); - this->ngk = psi_in.ngk; this->npol = psi_in.npol; - if (nband_in <= psi_in.get_nbands()) + this->allocate_inside = true; + + this->nk = nk_in; + this->nbands = nband_in; + this->nbasis = psi_in.get_nbasis(); + + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. + resize_memory_op()(this->ctx, + this->psi, + (static_cast(this->nk) * static_cast(this->nbands) + * static_cast(this->nbasis)), + "no_record"); + synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); + + this->current_k = 0; + this->current_b = 0; + this->current_nbasis = this->nbasis; + this->psi_current = this->psi; + this->psi_bias = 0; + + if (this->nk != psi_in.get_nk()) { - // copy from Psi from psi_in(current_k, 0, 0), - // if size of k is 1, current_k in new Psi is psi_in.current_k - if (nk_in == 1) - { - // current_k for this Psi only keep the spin index same as the copied Psi - this->current_k = psi_in.get_current_k(); - } - synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); + this->ngk = nullptr; + } + else + { + this->ngk = psi_in.ngk; } } - template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 042fd865d7..88f143df2e 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -131,6 +131,7 @@ class Psi // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; + int npol = 1; private: From 1c2f523affda8da23db65558b88266951ee143f2 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 12:21:23 +0000 Subject: [PATCH 07/16] update get_ngk usage --- source/module_hamilt_general/operator.cpp | 2 +- .../module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp | 2 +- source/module_psi/psi.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a99e813e01..db54852331 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -73,7 +73,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node); + op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas(), is_first_node); break; } }; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 858e6b3fd5..1a1196b864 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -47,7 +47,7 @@ void Velocity::act ) const { ModuleBase::timer::tick("Operator", "Velocity"); - const int npw = psi_in->get_ngk(this->ik); + const int npw = psi_in->get_current_nbas(); const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0; diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index cf2a926c22..f4ba33eedd 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -106,7 +106,7 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int } template -Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) +Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) { assert(nk_in <= psi_in.get_nk() && nk_in > 0); assert(nband_in <= psi_in.get_nbands() && nband_in > 0); From 1fb8851c4e7be7e5494c4dc41031867181dd617b Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 13:33:37 +0000 Subject: [PATCH 08/16] fix bug about ngk --- source/module_hamilt_general/operator.cpp | 5 +++++ source/module_hsolver/hsolver_pw.cpp | 27 ++++++++++++++--------- source/module_psi/psi.cpp | 3 ++- source/module_psi/psi.h | 1 + 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index db54852331..6f7a7cc77e 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -65,6 +65,11 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); + + // std::cout << "op->ik : " << op->ik << std::endl; + // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; + // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + switch (op->get_act_type()) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 0c1ad2e8b8..97f32aa587 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -378,10 +378,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ngk_vector[i] = ngk_pointer[i]; } + const int cur_nbasis = psi.get_current_nbas(); + if (this->method == "cg") { // wrap the subspace_func into a lambda function - auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) { + auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) { // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); @@ -391,12 +393,14 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, psi_in.shape().dim_size(0), psi_in.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); auto psi_out_wrapper = psi::Psi(psi_out.data(), 1, psi_out.shape().dim_size(0), psi_out.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); auto eigen = ct::Tensor(ct::DataTypeToEnum::value, ct::DeviceType::CpuDevice, ct::TensorShape({psi_in.shape().dim_size(0)})); @@ -415,7 +419,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, using ct_Device = typename ct::PsiToContainer::type; // wrap the hpsi_func and spsi_func into a lambda function - auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] @@ -426,7 +430,8 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, ndim == 1 ? 1 : psi_in.shape().dim_size(0), ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); @@ -486,11 +491,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband = psi.get_nbands(); const int nbasis = psi.get_nbasis(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -507,11 +512,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, else if (this->method == "dav_subspace") { // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -558,11 +563,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, // Davidson matrix-blockvector functions /// wrap hpsi into lambda function, Matrix \times blockvector // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("David", "hpsi_func"); // Convert pointer of psi_in to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index f4ba33eedd..3fcb347790 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -69,6 +69,7 @@ Psi::Psi(T* psi_pointer, const int nbd_in, const int nbs_in, const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in) { this->k_first = k_first_in; @@ -79,7 +80,7 @@ Psi::Psi(T* psi_pointer, this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; - this->current_nbasis = nbs_in; + this->current_nbasis = current_nbasis_in; this->psi_current = this->psi = psi_pointer; this->allocate_inside = false; // Currently only GPU's implementation is supported for device recording! diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 88f143df2e..860112f066 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -64,6 +64,7 @@ class Psi const int nbd_in, const int nbs_in, const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in = true); // Constructor 8-2: a pointer version of constructor 3 From 1a9aea99c3278074504e01c3e1b66b6e825a0aa9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 25 Dec 2024 14:18:30 +0000 Subject: [PATCH 09/16] [pre-commit.ci lite] apply automatic fixes --- .../module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 1a1196b864..94a671372b 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -23,7 +23,8 @@ Velocity::Velocity ModuleBase::WARNING_QUIT("Velocity", "Constuctor of Operator::Velocity is failed, please check your code!"); } this->tpiba = ucell_in -> tpiba; - if(this->nonlocal) this->ppcell->initgradq_vnl(*this->ucell); + if(this->nonlocal) { this->ppcell->initgradq_vnl(*this->ucell); +} } void Velocity::init(const int ik_in) From 9a3a9f07f1a89a050e6f76547428ba8262a16d8b Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 11:13:38 +0800 Subject: [PATCH 10/16] fix bug --- source/module_hamilt_general/operator.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 6f7a7cc77e..f36691b950 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -66,9 +66,11 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - // std::cout << "op->ik : " << op->ik << std::endl; - // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; - // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + std::cout << "op->ik : " << op->ik << std::endl; + std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; + std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + + std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; @@ -78,7 +80,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas(), is_first_node); + op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas() / psi_input->npol, is_first_node); break; } }; From 093c3f213ec9d943d800d1b5ca3c29e1024c011b Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 11:30:52 +0800 Subject: [PATCH 11/16] format operator --- source/module_hamilt_general/operator.cpp | 109 ++++++++++++---------- source/module_hamilt_general/operator.h | 76 ++++++++------- 2 files changed, 103 insertions(+), 82 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index f36691b950..a8c95955f7 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -4,28 +4,31 @@ using namespace hamilt; - -template -Operator::Operator(){} - -template -Operator::~Operator() +template +Operator::Operator() { - if(this->hpsi != nullptr) { delete this->hpsi; } + +template +Operator::~Operator() +{ + if (this->hpsi != nullptr) + { + delete this->hpsi; + } Operator* last = this->next_op; Operator* last_sub = this->next_sub_op; - while(last != nullptr || last_sub != nullptr) + while (last != nullptr || last_sub != nullptr) { - if(last_sub != nullptr) - {//delete sub_chain first + if (last_sub != nullptr) + { // delete sub_chain first Operator* node_delete = last_sub; last_sub = last_sub->next_sub_op; node_delete->next_sub_op = nullptr; delete node_delete; } else - {//delete main chain if sub_chain is deleted + { // delete main chain if sub_chain is deleted Operator* node_delete = last; last_sub = last->next_sub_op; node_delete->next_sub_op = nullptr; @@ -36,7 +39,7 @@ Operator::~Operator() } } -template +template typename Operator::hpsi_info Operator::hPsi(hpsi_info& input) const { using syncmem_op = base_device::memory::synchronize_memory_op; @@ -46,12 +49,12 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp T* tmhpsi = this->get_hpsi(input); const T* tmpsi_in = std::get<0>(psi_info); - //if range in hpsi_info is illegal, the first return of to_range() would be nullptr + // if range in hpsi_info is illegal, the first return of to_range() would be nullptr if (tmpsi_in == nullptr) { ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!"); } - //if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return + // if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return T* hpsi_pointer = std::get<2>(input); if (this->in_place) { @@ -62,28 +65,31 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp } auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void { - // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - std::cout << "op->ik : " << op->ik << std::endl; - std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; - std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + // std::cout << "op->ik : " << op->ik << std::endl; + // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; + // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + + // std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; - std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; - - - switch (op->get_act_type()) { case 2: op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas() / psi_input->npol, is_first_node); + op->act(nbands, + psi_input->get_nbasis(), + psi_input->npol, + tmpsi_in, + this->hpsi->get_pointer(), + psi_input->get_current_nbas() / psi_input->npol, + is_first_node); break; } - }; + }; ModuleBase::timer::tick("Operator", "hPsi"); call_act(this, true); // first node @@ -98,39 +104,43 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); } - -template -void Operator::init(const int ik_in) +template +void Operator::init(const int ik_in) { this->ik = ik_in; - if(this->next_op != nullptr) { + if (this->next_op != nullptr) + { this->next_op->init(ik_in); } } -template -void Operator::add(Operator* next) +template +void Operator::add(Operator* next) { - if(next==nullptr) { return; -} + if (next == nullptr) + { + return; + } next->is_first_node = false; - if(next->next_op != nullptr) { this->add(next->next_op); -} + if (next->next_op != nullptr) + { + this->add(next->next_op); + } Operator* last = this; - //loop to end of the chain - while(last->next_op != nullptr) + // loop to end of the chain + while (last->next_op != nullptr) { - if(next->cal_type==last->cal_type) + if (next->cal_type == last->cal_type) { break; } last = last->next_op; } - if(next->cal_type == last->cal_type) + if (next->cal_type == last->cal_type) { - //insert next to sub chain of current node + // insert next to sub chain of current node Operator* sub_last = last; - while(sub_last->next_sub_op != nullptr) + while (sub_last->next_sub_op != nullptr) { sub_last = sub_last->next_sub_op; } @@ -143,24 +153,24 @@ void Operator::add(Operator* next) } } -template +template T* Operator::get_hpsi(const hpsi_info& info) const { const int nbands_range = (std::get<1>(info).range_2 - std::get<1>(info).range_1 + 1); - //in_place call of hPsi, hpsi inputs as new psi, - //create a new hpsi and delete old hpsi later + // in_place call of hPsi, hpsi inputs as new psi, + // create a new hpsi and delete old hpsi later T* hpsi_pointer = std::get<2>(info); const T* psi_pointer = std::get<0>(info)->get_pointer(); - if(this->hpsi != nullptr) + if (this->hpsi != nullptr) { delete this->hpsi; this->hpsi = nullptr; } - if(!hpsi_pointer) + if (!hpsi_pointer) { ModuleBase::WARNING_QUIT("Operator::hPsi", "hpsi_pointer can not be nullptr"); } - else if(hpsi_pointer == psi_pointer) + else if (hpsi_pointer == psi_pointer) { this->in_place = true; this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); @@ -170,7 +180,7 @@ T* Operator::get_hpsi(const hpsi_info& info) const this->in_place = false; this->hpsi = new psi::Psi(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range); } - + hpsi_pointer = this->hpsi->get_pointer(); size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis(); // ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size); @@ -179,7 +189,8 @@ T* Operator::get_hpsi(const hpsi_info& info) const return hpsi_pointer; } -namespace hamilt { +namespace hamilt +{ template class Operator; template class Operator, base_device::DEVICE_CPU>; template class Operator; @@ -190,4 +201,4 @@ template class Operator, base_device::DEVICE_GPU>; template class Operator; template class Operator, base_device::DEVICE_GPU>; #endif -} +} // namespace hamilt diff --git a/source/module_hamilt_general/operator.h b/source/module_hamilt_general/operator.h index 6cf29122fe..80ed065ccc 100644 --- a/source/module_hamilt_general/operator.h +++ b/source/module_hamilt_general/operator.h @@ -1,19 +1,19 @@ #ifndef OPERATOR_H #define OPERATOR_H -#include - #include "module_base/global_function.h" #include "module_base/tool_quit.h" #include "module_psi/psi.h" +#include + namespace hamilt { enum class calculation_type { no, - pw_ekinetic, + pw_ekinetic, pw_nonlocal, pw_veff, pw_meta, @@ -28,49 +28,54 @@ enum class calculation_type lcao_tddft_velocity, }; -// Basic class for operator module, +// Basic class for operator module, // it is designed for "O|psi>" and "" // Operator "O" might have several different types, which should be calculated one by one. // In basic class , function add() is designed for combine all operators together with a chain. template class Operator { - public: + public: Operator(); virtual ~Operator(); - //this is the core function for Operator - // do H|psi> from input |psi> , + // this is the core function for Operator + // do H|psi> from input |psi> , /// as default, different operators donate hPsi independently - /// run this->act function for the first operator and run all act() for other nodes in chain table + /// run this->act function for the first operator and run all act() for other nodes in chain table /// if this procedure is not suitable for your operator, just override this function. - /// output of hpsi would be first member of the returned tuple + /// output of hpsi would be first member of the returned tuple typedef std::tuple*, const psi::Range, T*> hpsi_info; - virtual hpsi_info hPsi(hpsi_info& input)const; + + virtual hpsi_info hPsi(hpsi_info& input) const; virtual void init(const int ik_in); virtual void add(Operator* next); - virtual int get_ik() const { return this->ik; } + virtual int get_ik() const + { + return this->ik; + } - ///do operation : |hpsi_choosed> = V|psi_choosed> - ///V is the target operator act on choosed psi, the consequence should be added to choosed hpsi - /// interface type 1: pointer-only (default) - /// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!! + /// do operation : |hpsi_choosed> = V|psi_choosed> + /// V is the target operator act on choosed psi, the consequence should be added to choosed hpsi + /// interface type 1: pointer-only (default) + /// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!! virtual void act(const int nbands, - const int nbasis, - const int npol, - const T* tmpsi_in, - T* tmhpsi, - const int ngk_ik = 0, - const bool is_first_node = false)const {}; + const int nbasis, + const int npol, + const T* tmpsi_in, + T* tmhpsi, + const int ngk_ik = 0, + const bool is_first_node = false) const {}; /// developer-friendly interfaces for act() function /// interface type 2: input and change the Psi-type HPsi // virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out) const {}; virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out, const int nbands) const {}; + /// interface type 3: return a Psi-type HPsi // virtual psi::Psi act(const psi::Psi& psi_in) const { return psi_in; }; @@ -78,36 +83,41 @@ class Operator /// type 1 (default): pointer-only /// act(const T* psi_in, T* psi_out) - /// type 2: use the `Psi`class + /// type 2: use the `Psi`class /// act(const Psi& psi_in, Psi& psi_out) - int get_act_type() const { return this->act_type; } -protected: + int get_act_type() const + { + return this->act_type; + } + + protected: int ik = 0; - int act_type = 1; ///< determine which act() interface would be called in hPsi() + int act_type = 1; ///< determine which act() interface would be called in hPsi() mutable bool in_place = false; - //calculation type, only different type can be in main chain table + // calculation type, only different type can be in main chain table enum calculation_type cal_type; Operator* next_sub_op = nullptr; bool is_first_node = true; - //if this Operator is first node in chain table, hpsi would not be empty + // if this Operator is first node in chain table, hpsi would not be empty mutable psi::Psi* hpsi = nullptr; /*This function would analyze hpsi_info and choose how to arrange hpsi storage In hpsi_info, if the third parameter hpsi_pointer is set, which indicates memory of hpsi is arranged by developer; - if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare. + if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare. two cases would occurred: - 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary memory - 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case + 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary + memory + 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case */ - T* get_hpsi(const hpsi_info& info)const; + T* get_hpsi(const hpsi_info& info) const; - Device *ctx = {}; + Device* ctx = {}; using set_memory_op = base_device::memory::set_memory_op; }; -}//end namespace hamilt +} // end namespace hamilt #endif \ No newline at end of file From af1b7bc7028155c868ac11038a8e22c60088e1c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Thu, 26 Dec 2024 04:11:55 +0000 Subject: [PATCH 12/16] [pre-commit.ci lite] apply automatic fixes --- source/module_io/write_proj_band_lcao.cpp | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index 34843207b7..b5660f7da5 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -25,8 +25,9 @@ void ModuleIO::write_proj_band_lcao( const double* sk = dynamic_cast*>(p_ham)->getSk(); int nspin0 = 1; - if (PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 2) { nspin0 = 2; +} int nks = 0; if (nspin0 == 1) { @@ -103,14 +104,16 @@ void ModuleIO::write_proj_band_lcao( out << "" << std::endl; out << "" << PARAM.inp.nspin << "" << std::endl; - if (PARAM.inp.nspin == 4) + if (PARAM.inp.nspin == 4) { out << "" << std::setw(2) << PARAM.globalv.nlocal / 2 << "" << std::endl; - else + } else { out << "" << std::setw(2) << PARAM.globalv.nlocal << "" << std::endl; +} out << "" << std::endl; - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.inp.nbands; ib++) { out << " " << (pelec->ekb(is * nks, ib)) * ModuleBase::Ry_to_eV; +} out << std::endl; out << "" << std::endl; @@ -139,9 +142,9 @@ void ModuleIO::write_proj_band_lcao( out << "" << std::endl; for (int ib = 0; ib < PARAM.inp.nbands; ib++) { - if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) { out << std::setw(13) << weight(is, ib * PARAM.globalv.nlocal + w); - else if (PARAM.inp.nspin == 4) + } else if (PARAM.inp.nspin == 4) { int w0 = w - s0; out << std::setw(13) @@ -178,8 +181,9 @@ void ModuleIO::write_proj_band_lcao( ModuleBase::timer::tick("ModuleIO", "write_proj_band_lcao"); int nspin0 = 1; - if (PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 2) { nspin0 = 2; +} int nks = 0; if (nspin0 == 1) { @@ -302,8 +306,9 @@ void ModuleIO::write_proj_band_lcao( for (int ik = 0; ik < nks; ik++) { - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.inp.nbands; ib++) { out << " " << (pelec->ekb(ik + is * nks, ib)) * ModuleBase::Ry_to_eV; +} out << std::endl; } out << "" << std::endl; From 16687c3baf5a7e18ce079419ec5019ea52b0fd0f Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 07:03:05 +0000 Subject: [PATCH 13/16] fix bug --- source/module_hamilt_general/operator.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a8c95955f7..dce5335db4 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -74,6 +74,10 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; + // std::cout << "psi_input->npol : " << psi_input->npol << std::endl; + + + switch (op->get_act_type()) { case 2: @@ -85,7 +89,10 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - psi_input->get_current_nbas() / psi_input->npol, + psi_input->get_ngk(op->ik), + // 0, + // psi_input->get_current_nbas(), + // psi_input->get_current_nbas() / psi_input->npol, is_first_node); break; } From 35d26d65720b0e2b7ea8222cbe0556897ece5902 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 15:55:12 +0800 Subject: [PATCH 14/16] fix bug --- source/module_hamilt_general/operator.cpp | 4 ++-- source/module_hsolver/hsolver_pw.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index dce5335db4..fbd1525805 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -89,9 +89,9 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - psi_input->get_ngk(op->ik), + // psi_input->get_ngk(op->ik), // 0, - // psi_input->get_current_nbas(), + psi_input->get_current_nbas(), // psi_input->get_current_nbas() / psi_input->npol, is_first_node); break; diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 97f32aa587..a885296c62 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -378,7 +378,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ngk_vector[i] = ngk_pointer[i]; } - const int cur_nbasis = psi.get_current_nbas(); + const int cur_nbasis = psi.get_ngk(psi.get_current_k()); if (this->method == "cg") { From 3096085262c07e1edb4de90031c89e29481df527 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 09:13:00 +0000 Subject: [PATCH 15/16] fix bug --- source/module_hamilt_general/operator.cpp | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index fbd1525805..4a830489f4 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -68,16 +68,6 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - // std::cout << "op->ik : " << op->ik << std::endl; - // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; - // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; - - // std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; - - // std::cout << "psi_input->npol : " << psi_input->npol << std::endl; - - - switch (op->get_act_type()) { case 2: @@ -89,10 +79,8 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - // psi_input->get_ngk(op->ik), - // 0, - psi_input->get_current_nbas(), - // psi_input->get_current_nbas() / psi_input->npol, + psi_input->get_ngk(op->ik), + // psi_input->get_current_nbas(), is_first_node); break; } From a3817e4983c0147697ab4f46c066ec24a97678b6 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 10:06:41 +0000 Subject: [PATCH 16/16] fix bug --- .../module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 94a671372b..0d49cadaa0 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -48,7 +48,7 @@ void Velocity::act ) const { ModuleBase::timer::tick("Operator", "Velocity"); - const int npw = psi_in->get_current_nbas(); + const int npw = psi_in->get_ngk(this->ik); const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0;