@@ -240,23 +240,49 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(
240
240
// std::vector<T> hpsi(psi_temp.get_nbands() * psi_temp.get_nbasis());
241
241
242
242
// do hPsi for all bands
243
- psi::Range all_bands_range (1 , psi_temp.get_current_k (), 0 , psi_temp.get_nbands ()-1 );
244
- hpsi_info hpsi_in (&psi_temp, all_bands_range, hpsi);
245
- if (pHamilt->ops == nullptr )
243
+ if (base_device::get_device_type (ctx) == base_device::GpuDevice)
246
244
{
247
- ModuleBase::WARNING (" DiagoIterAssist::diagH_subspace_init" ,
248
- " Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n " );
249
- for (int iband = 0 ; iband < evc.get_nbands (); iband++)
245
+ for (int i = 0 ; i < psi_temp.get_nbands (); i++)
250
246
{
251
- for (int ig = 0 ; ig < evc.get_nbasis (); ig++)
247
+ psi::Range band_by_band_range (1 , psi_temp.get_current_k (), i, i);
248
+ hpsi_info hpsi_in (&psi_temp, band_by_band_range, hpsi + i * psi_temp.get_nbasis ());
249
+ if (pHamilt->ops == nullptr )
252
250
{
253
- evc (iband, ig) = psi[iband * evc.get_nbasis () + ig];
251
+ ModuleBase::WARNING (" DiagoIterAssist::diagH_subspace_init" ,
252
+ " Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n " );
253
+ for (int iband = 0 ; iband < evc.get_nbands (); iband++)
254
+ {
255
+ for (int ig = 0 ; ig < evc.get_nbasis (); ig++)
256
+ {
257
+ evc (iband, ig) = psi[iband * evc.get_nbasis () + ig];
258
+ }
259
+ en[iband] = 0.0 ;
260
+ }
261
+ return ;
254
262
}
255
- en[iband] = 0.0 ;
263
+ pHamilt-> ops -> hPsi (hpsi_in) ;
256
264
}
257
- return ;
258
265
}
259
- pHamilt->ops ->hPsi (hpsi_in);
266
+ else if (base_device::get_device_type (ctx) == base_device::CpuDevice)
267
+ {
268
+ psi::Range all_bands_range (1 , psi_temp.get_current_k (), 0 , psi_temp.get_nbands ()-1 );
269
+ hpsi_info hpsi_in (&psi_temp, all_bands_range, hpsi);
270
+ if (pHamilt->ops == nullptr )
271
+ {
272
+ ModuleBase::WARNING (" DiagoIterAssist::diagH_subspace_init" ,
273
+ " Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n " );
274
+ for (int iband = 0 ; iband < evc.get_nbands (); iband++)
275
+ {
276
+ for (int ig = 0 ; ig < evc.get_nbasis (); ig++)
277
+ {
278
+ evc (iband, ig) = psi[iband * evc.get_nbasis () + ig];
279
+ }
280
+ en[iband] = 0.0 ;
281
+ }
282
+ return ;
283
+ }
284
+ pHamilt->ops ->hPsi (hpsi_in);
285
+ }
260
286
261
287
gemm_op<T, Device>()(
262
288
ctx,
0 commit comments