Skip to content

Commit d6c5d43

Browse files
haozhihanmohanchen
andauthored
Memory: modify hpsi interface in diagH_subspace_init func (#4167)
* modify hpsi interface in diagH_subspace_init * fix build bug --------- Co-authored-by: Mohan Chen <[email protected]>
1 parent 0ecc169 commit d6c5d43

File tree

1 file changed

+37
-11
lines changed

1 file changed

+37
-11
lines changed

source/module_hsolver/diago_iter_assist.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -240,23 +240,49 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(
240240
// std::vector<T> hpsi(psi_temp.get_nbands() * psi_temp.get_nbasis());
241241

242242
// do hPsi for all bands
243-
psi::Range all_bands_range(1, psi_temp.get_current_k(), 0, psi_temp.get_nbands()-1);
244-
hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi);
245-
if(pHamilt->ops == nullptr)
243+
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
246244
{
247-
ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init",
248-
"Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n");
249-
for(int iband = 0; iband < evc.get_nbands(); iband++)
245+
for (int i = 0; i < psi_temp.get_nbands(); i++)
250246
{
251-
for(int ig = 0; ig < evc.get_nbasis(); ig++)
247+
psi::Range band_by_band_range(1, psi_temp.get_current_k(), i, i);
248+
hpsi_info hpsi_in(&psi_temp, band_by_band_range, hpsi + i * psi_temp.get_nbasis());
249+
if(pHamilt->ops == nullptr)
252250
{
253-
evc(iband, ig) = psi[iband * evc.get_nbasis() + ig];
251+
ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init",
252+
"Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n");
253+
for(int iband = 0; iband < evc.get_nbands(); iband++)
254+
{
255+
for(int ig = 0; ig < evc.get_nbasis(); ig++)
256+
{
257+
evc(iband, ig) = psi[iband * evc.get_nbasis() + ig];
258+
}
259+
en[iband] = 0.0;
260+
}
261+
return;
254262
}
255-
en[iband] = 0.0;
263+
pHamilt->ops->hPsi(hpsi_in);
256264
}
257-
return;
258265
}
259-
pHamilt->ops->hPsi(hpsi_in);
266+
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
267+
{
268+
psi::Range all_bands_range(1, psi_temp.get_current_k(), 0, psi_temp.get_nbands()-1);
269+
hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi);
270+
if(pHamilt->ops == nullptr)
271+
{
272+
ModuleBase::WARNING("DiagoIterAssist::diagH_subspace_init",
273+
"Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n");
274+
for(int iband = 0; iband < evc.get_nbands(); iband++)
275+
{
276+
for(int ig = 0; ig < evc.get_nbasis(); ig++)
277+
{
278+
evc(iband, ig) = psi[iband * evc.get_nbasis() + ig];
279+
}
280+
en[iband] = 0.0;
281+
}
282+
return;
283+
}
284+
pHamilt->ops->hPsi(hpsi_in);
285+
}
260286

261287
gemm_op<T, Device>()(
262288
ctx,

0 commit comments

Comments
 (0)