diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index db54852331..6f7a7cc77e 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -65,6 +65,11 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); + + // std::cout << "op->ik : " << op->ik << std::endl; + // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; + // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + switch (op->get_act_type()) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 0c1ad2e8b8..97f32aa587 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -378,10 +378,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ngk_vector[i] = ngk_pointer[i]; } + const int cur_nbasis = psi.get_current_nbas(); + if (this->method == "cg") { // wrap the subspace_func into a lambda function - auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) { + auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) { // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); @@ -391,12 +393,14 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, psi_in.shape().dim_size(0), psi_in.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); auto psi_out_wrapper = psi::Psi(psi_out.data(), 1, psi_out.shape().dim_size(0), psi_out.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); auto eigen = ct::Tensor(ct::DataTypeToEnum::value, ct::DeviceType::CpuDevice, ct::TensorShape({psi_in.shape().dim_size(0)})); @@ -415,7 +419,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, using ct_Device = typename ct::PsiToContainer::type; // wrap the hpsi_func and spsi_func into a lambda function - auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] @@ -426,7 +430,8 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, ndim == 1 ? 1 : psi_in.shape().dim_size(0), ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); @@ -486,11 +491,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband = psi.get_nbands(); const int nbasis = psi.get_nbasis(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -507,11 +512,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, else if (this->method == "dav_subspace") { // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -558,11 +563,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, // Davidson matrix-blockvector functions /// wrap hpsi into lambda function, Matrix \times blockvector // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("David", "hpsi_func"); // Convert pointer of psi_in to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index f4ba33eedd..3fcb347790 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -69,6 +69,7 @@ Psi::Psi(T* psi_pointer, const int nbd_in, const int nbs_in, const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in) { this->k_first = k_first_in; @@ -79,7 +80,7 @@ Psi::Psi(T* psi_pointer, this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; - this->current_nbasis = nbs_in; + this->current_nbasis = current_nbasis_in; this->psi_current = this->psi = psi_pointer; this->allocate_inside = false; // Currently only GPU's implementation is supported for device recording! diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 88f143df2e..860112f066 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -64,6 +64,7 @@ class Psi const int nbd_in, const int nbs_in, const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in = true); // Constructor 8-2: a pointer version of constructor 3