Skip to content

Commit

Permalink
fix bug about ngk
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhihan committed Dec 25, 2024
1 parent 1c2f523 commit 1fb8851
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
5 changes: 5 additions & 0 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp

// a "psi" with the bands of needed range
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);

// std::cout << "op->ik : " << op->ik << std::endl;
// std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl;
// std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl;



switch (op->get_act_type())
Expand Down
27 changes: 16 additions & 11 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
ngk_vector[i] = ngk_pointer[i];
}

const int cur_nbasis = psi.get_current_nbas();

if (this->method == "cg")
{
// wrap the subspace_func into a lambda function
auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
const auto ndim = psi_in.shape().ndim();
Expand All @@ -391,12 +393,14 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
1,
psi_in.shape().dim_size(0),
psi_in.shape().dim_size(1),
ngk_vector);
ngk_vector,
cur_nbasis);
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
1,
psi_out.shape().dim_size(0),
psi_out.shape().dim_size(1),
ngk_vector);
ngk_vector,
cur_nbasis);
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
ct::DeviceType::CpuDevice,
ct::TensorShape({psi_in.shape().dim_size(0)}));
Expand All @@ -415,7 +419,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
using ct_Device = typename ct::PsiToContainer<Device>::type;

// wrap the hpsi_func and spsi_func into a lambda function
auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
Expand All @@ -426,7 +430,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
1,
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
ngk_vector);
ngk_vector,
cur_nbasis);
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
Expand Down Expand Up @@ -486,11 +491,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
const int nband = psi.get_nbands();
const int nbasis = psi.get_nbasis();
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("DavSubspace", "hpsi_func");

// Convert "pointer data stucture" to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);

psi::Range bands_range(true, 0, 0, nvec - 1);

Expand All @@ -507,11 +512,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
else if (this->method == "dav_subspace")
{
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("DavSubspace", "hpsi_func");

// Convert "pointer data stucture" to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);

psi::Range bands_range(true, 0, 0, nvec - 1);

Expand Down Expand Up @@ -558,11 +563,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
// Davidson matrix-blockvector functions
/// wrap hpsi into lambda function, Matrix \times blockvector
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("David", "hpsi_func");

// Convert pointer of psi_in to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);

psi::Range bands_range(true, 0, 0, nvec - 1);

Expand Down
3 changes: 2 additions & 1 deletion source/module_psi/psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
const int nbd_in,
const int nbs_in,
const std::vector<int>& ngk_vector_in,
const int current_nbasis_in,
const bool k_first_in)
{
this->k_first = k_first_in;
Expand All @@ -79,7 +80,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
this->nk = nk_in;
this->nbands = nbd_in;
this->nbasis = nbs_in;
this->current_nbasis = nbs_in;
this->current_nbasis = current_nbasis_in;
this->psi_current = this->psi = psi_pointer;
this->allocate_inside = false;
// Currently only GPU's implementation is supported for device recording!
Expand Down
1 change: 1 addition & 0 deletions source/module_psi/psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class Psi
const int nbd_in,
const int nbs_in,
const std::vector<int>& ngk_vector_in,
const int current_nbasis_in,
const bool k_first_in = true);

// Constructor 8-2: a pointer version of constructor 3
Expand Down

0 comments on commit 1fb8851

Please sign in to comment.