Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

not ready #5761

Draft
wants to merge 12 commits into
base: develop
Choose a base branch
from
8 changes: 5 additions & 3 deletions source/module_elecstate/cal_dm.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
//dm.fix_k(ik);
dm[ik].create(ParaV->ncol, ParaV->nrow);
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
psi::Psi<double> wg_wfc(wfc, 1);
psi::Psi<double> wg_wfc(wfc, 1, nbands_local);

int ib_global = 0;
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
Expand All @@ -41,7 +41,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc) continue;
if (ib_global >= wg.nc) { continue;
}
const double wg_local = wg(ik, ib_global);
double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand Down Expand Up @@ -99,7 +100,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc) continue;
if (ib_global >= wg.nc) { continue;
}
const double wg_local = wg(ik, ib_global);
std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand Down
3 changes: 2 additions & 1 deletion source/module_elecstate/module_dm/cal_dm_psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,
// dm.fix_k(ik);
// dm[ik].create(ParaV->ncol, ParaV->nrow);
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
psi::Psi<double> wg_wfc(wfc, 1);

psi::Psi<double> wg_wfc(wfc, 1, nbands_local);

int ib_global = 0;
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
Expand Down
106 changes: 62 additions & 44 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,31 @@

using namespace hamilt;


template<typename T, typename Device>
Operator<T, Device>::Operator(){}

template<typename T, typename Device>
Operator<T, Device>::~Operator()
template <typename T, typename Device>
Operator<T, Device>::Operator()
{
if(this->hpsi != nullptr) { delete this->hpsi;
}

template <typename T, typename Device>
Operator<T, Device>::~Operator()
{
if (this->hpsi != nullptr)
{
delete this->hpsi;
}
Operator* last = this->next_op;
Operator* last_sub = this->next_sub_op;
while(last != nullptr || last_sub != nullptr)
while (last != nullptr || last_sub != nullptr)
{
if(last_sub != nullptr)
{//delete sub_chain first
if (last_sub != nullptr)
{ // delete sub_chain first
Operator* node_delete = last_sub;
last_sub = last_sub->next_sub_op;
node_delete->next_sub_op = nullptr;
delete node_delete;
}
else
{//delete main chain if sub_chain is deleted
{ // delete main chain if sub_chain is deleted
Operator* node_delete = last;
last_sub = last->next_sub_op;
node_delete->next_sub_op = nullptr;
Expand All @@ -36,7 +39,7 @@ Operator<T, Device>::~Operator()
}
}

template<typename T, typename Device>
template <typename T, typename Device>
typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& input) const
{
using syncmem_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
Expand All @@ -46,12 +49,12 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp

T* tmhpsi = this->get_hpsi(input);
const T* tmpsi_in = std::get<0>(psi_info);
//if range in hpsi_info is illegal, the first return of to_range() would be nullptr
// if range in hpsi_info is illegal, the first return of to_range() would be nullptr
if (tmpsi_in == nullptr)
{
ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!");
}
//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
// if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
T* hpsi_pointer = std::get<2>(input);
if (this->in_place)
{
Expand All @@ -62,21 +65,31 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
}

auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {

// a "psi" with the bands of needed range
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);



// std::cout << "op->ik : " << op->ik << std::endl;
// std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl;
// std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl;

// std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl;

switch (op->get_act_type())
{
case 2:
op->act(psi_wrapper, *this->hpsi, nbands);
break;
default:
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node);
op->act(nbands,
psi_input->get_nbasis(),
psi_input->npol,
tmpsi_in,
this->hpsi->get_pointer(),
psi_input->get_current_nbas() / psi_input->npol,
is_first_node);
break;
}
};
};

ModuleBase::timer::tick("Operator", "hPsi");
call_act(this, true); // first node
Expand All @@ -91,39 +104,43 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
}


template<typename T, typename Device>
void Operator<T, Device>::init(const int ik_in)
template <typename T, typename Device>
void Operator<T, Device>::init(const int ik_in)
{
this->ik = ik_in;
if(this->next_op != nullptr) {
if (this->next_op != nullptr)
{
this->next_op->init(ik_in);
}
}

template<typename T, typename Device>
void Operator<T, Device>::add(Operator* next)
template <typename T, typename Device>
void Operator<T, Device>::add(Operator* next)
{
if(next==nullptr) { return;
}
if (next == nullptr)
{
return;
}
next->is_first_node = false;
if(next->next_op != nullptr) { this->add(next->next_op);
}
if (next->next_op != nullptr)
{
this->add(next->next_op);
}
Operator* last = this;
//loop to end of the chain
while(last->next_op != nullptr)
// loop to end of the chain
while (last->next_op != nullptr)
{
if(next->cal_type==last->cal_type)
if (next->cal_type == last->cal_type)
{
break;
}
last = last->next_op;
}
if(next->cal_type == last->cal_type)
if (next->cal_type == last->cal_type)
{
//insert next to sub chain of current node
// insert next to sub chain of current node
Operator* sub_last = last;
while(sub_last->next_sub_op != nullptr)
while (sub_last->next_sub_op != nullptr)
{
sub_last = sub_last->next_sub_op;
}
Expand All @@ -136,24 +153,24 @@ void Operator<T, Device>::add(Operator* next)
}
}

template<typename T, typename Device>
template <typename T, typename Device>
T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
{
const int nbands_range = (std::get<1>(info).range_2 - std::get<1>(info).range_1 + 1);
//in_place call of hPsi, hpsi inputs as new psi,
//create a new hpsi and delete old hpsi later
// in_place call of hPsi, hpsi inputs as new psi,
// create a new hpsi and delete old hpsi later
T* hpsi_pointer = std::get<2>(info);
const T* psi_pointer = std::get<0>(info)->get_pointer();
if(this->hpsi != nullptr)
if (this->hpsi != nullptr)
{
delete this->hpsi;
this->hpsi = nullptr;
}
if(!hpsi_pointer)
if (!hpsi_pointer)
{
ModuleBase::WARNING_QUIT("Operator::hPsi", "hpsi_pointer can not be nullptr");
}
else if(hpsi_pointer == psi_pointer)
else if (hpsi_pointer == psi_pointer)
{
this->in_place = true;
this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
Expand All @@ -163,7 +180,7 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
this->in_place = false;
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range);
}

hpsi_pointer = this->hpsi->get_pointer();
size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis();
// ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
Expand All @@ -172,7 +189,8 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
return hpsi_pointer;
}

namespace hamilt {
namespace hamilt
{
template class Operator<float, base_device::DEVICE_CPU>;
template class Operator<std::complex<float>, base_device::DEVICE_CPU>;
template class Operator<double, base_device::DEVICE_CPU>;
Expand All @@ -183,4 +201,4 @@ template class Operator<std::complex<float>, base_device::DEVICE_GPU>;
template class Operator<double, base_device::DEVICE_GPU>;
template class Operator<std::complex<double>, base_device::DEVICE_GPU>;
#endif
}
} // namespace hamilt
Loading
Loading