@@ -134,8 +134,7 @@ void Diago_DavSubspace<T, Device>::diag_once(hamilt::Hamilt<T, Device>* phm_in,
134
134
basis,
135
135
this ->hphi ,
136
136
this ->hcc ,
137
- this ->scc ,
138
- true );
137
+ this ->scc );
139
138
140
139
this ->diag_zhegvx (nbase,
141
140
this ->n_band ,
@@ -175,8 +174,7 @@ void Diago_DavSubspace<T, Device>::diag_once(hamilt::Hamilt<T, Device>* phm_in,
175
174
basis,
176
175
this ->hphi ,
177
176
this ->hcc ,
178
- this ->scc ,
179
- false );
177
+ this ->scc );
180
178
181
179
this ->diag_zhegvx (nbase,
182
180
this ->n_band ,
@@ -399,84 +397,43 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
399
397
const psi::Psi<T, Device>& basis,
400
398
const T* hphi,
401
399
T* hcc,
402
- T* scc,
403
- bool init)
400
+ T* scc)
404
401
{
405
402
ModuleBase::timer::tick (" Diago_DavSubspace" , " cal_elem" );
406
403
407
- if (init)
408
- {
409
- assert (nbase == 0 );
410
- assert (this ->n_band == notconv);
411
- gemm_op<T, Device>()(this ->ctx ,
412
- ' C' ,
413
- ' N' ,
414
- notconv,
415
- notconv,
416
- this ->dim ,
417
- this ->one ,
418
- &basis (0 , 0 ),
419
- this ->dim ,
420
- hphi,
421
- this ->dim ,
422
- this ->zero ,
423
- hcc,
424
- this ->nbase_x );
425
-
426
- gemm_op<T, Device>()(this ->ctx ,
427
- ' C' ,
428
- ' N' ,
429
- notconv,
430
- notconv,
431
- this ->dim ,
432
- this ->one ,
433
- &basis (0 , 0 ),
434
- this ->dim ,
435
- &basis (0 , 0 ),
436
- this ->dim ,
437
- this ->zero ,
438
- scc,
439
- this ->nbase_x );
440
- }
441
- else
442
- {
443
- gemm_op<T, Device>()(this ->ctx ,
444
- ' C' ,
445
- ' N' ,
446
- notconv,
447
- nbase + notconv,
448
- this ->dim ,
449
- this ->one ,
450
- &hphi[nbase * this ->dim ],
451
- this ->dim ,
452
- &basis (0 , 0 ),
453
- this ->dim ,
454
- this ->zero ,
455
- hcc + nbase,
456
- this ->nbase_x );
404
+ gemm_op<T, Device>()(this ->ctx ,
405
+ ' C' ,
406
+ ' N' ,
407
+ nbase + notconv,
408
+ notconv,
409
+ this ->dim ,
410
+ this ->one ,
411
+ &basis (0 , 0 ),
412
+ this ->dim ,
413
+ &hphi[nbase * this ->dim ],
414
+ this ->dim ,
415
+ this ->zero ,
416
+ &hcc[nbase * this ->nbase_x ],
417
+ this ->nbase_x );
457
418
458
- gemm_op<T, Device>()(this ->ctx ,
459
- ' C' ,
460
- ' N' ,
461
- notconv,
462
- nbase + notconv,
463
- this ->dim ,
464
- this ->one ,
465
- &basis (nbase, 0 ),
466
- this ->dim ,
467
- &basis (0 , 0 ),
468
- this ->dim ,
469
- this ->zero ,
470
- scc + nbase,
471
- this ->nbase_x );
472
- }
419
+ gemm_op<T, Device>()(this ->ctx ,
420
+ ' C' ,
421
+ ' N' ,
422
+ nbase + notconv,
423
+ notconv,
424
+ this ->dim ,
425
+ this ->one ,
426
+ &basis (0 , 0 ),
427
+ this ->dim ,
428
+ &basis (nbase, 0 ),
429
+ this ->dim ,
430
+ this ->zero ,
431
+ &scc[nbase * this ->nbase_x ],
432
+ this ->nbase_x );
473
433
474
434
#ifdef __MPI
475
435
if (GlobalV::NPROC_IN_POOL > 1 )
476
436
{
477
- matrixTranspose_op<T, Device>()(this ->ctx , this ->nbase_x , this ->nbase_x , hcc, hcc);
478
- matrixTranspose_op<T, Device>()(this ->ctx , this ->nbase_x , this ->nbase_x , scc, scc);
479
-
480
437
auto * swap = new T[notconv * this ->nbase_x ];
481
438
syncmem_complex_op ()(this ->ctx , this ->ctx , swap, hcc + nbase * this ->nbase_x , notconv * this ->nbase_x );
482
439
@@ -532,64 +489,34 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
532
489
}
533
490
}
534
491
delete[] swap;
535
-
536
- matrixTranspose_op<T, Device>()(this ->ctx , this ->nbase_x , this ->nbase_x , hcc, hcc);
537
- matrixTranspose_op<T, Device>()(this ->ctx , this ->nbase_x , this ->nbase_x , scc, scc);
538
492
}
539
493
#endif
540
494
541
- nbase += notconv;
542
- int nb1 = nbase - notconv;
543
- // reset:
544
- if (init )
495
+ const size_t last_nbase = nbase; // init: last_nbase = 0
496
+ nbase = nbase + notconv;
497
+
498
+ for ( size_t i = 0 ; i < nbase; i++ )
545
499
{
546
- for ( size_t i = 0 ; i < nbase; i++ )
500
+ if (i >= last_nbase )
547
501
{
548
-
549
502
hcc[i * this ->nbase_x + i] = set_real_tocomplex (hcc[i * this ->nbase_x + i]);
550
503
scc[i * this ->nbase_x + i] = set_real_tocomplex (scc[i * this ->nbase_x + i]);
551
-
552
- for (size_t j = i + 1 ; j < nbase; j++)
553
- {
554
- hcc[j * this ->nbase_x + i] = get_conj (hcc[i * this ->nbase_x + j]);
555
- scc[j * this ->nbase_x + i] = get_conj (scc[i * this ->nbase_x + j]);
556
- }
557
504
}
558
- for (size_t i = nbase; i < this -> nbase_x ; i ++)
505
+ for (size_t j = std::max (i + 1 , last_nbase); j < nbase; j ++)
559
506
{
560
- for (size_t j = nbase; j < this ->nbase_x ; j++)
561
- {
562
- hcc[i * this ->nbase_x + j] = cs.zero ;
563
- scc[i * this ->nbase_x + j] = cs.zero ;
564
- hcc[j * this ->nbase_x + i] = cs.zero ;
565
- scc[j * this ->nbase_x + i] = cs.zero ;
566
- }
507
+ hcc[i * this ->nbase_x + j] = get_conj (hcc[j * this ->nbase_x + i]);
508
+ scc[i * this ->nbase_x + j] = get_conj (scc[j * this ->nbase_x + i]);
567
509
}
568
510
}
569
- else
511
+
512
+ for (size_t i = nbase; i < this ->nbase_x ; i++)
570
513
{
571
- for (size_t i = 0 ; i < nbase; i ++)
514
+ for (size_t j = nbase; j < this -> nbase_x ; j ++)
572
515
{
573
- if (i >= nb1)
574
- {
575
- hcc[i * this ->nbase_x + i] = set_real_tocomplex (hcc[i * this ->nbase_x + i]);
576
- scc[i * this ->nbase_x + i] = set_real_tocomplex (scc[i * this ->nbase_x + i]);
577
- }
578
- for (size_t j = std::max (i + 1 , (size_t )nb1); j < nbase; j++)
579
- {
580
- hcc[j * this ->nbase_x + i] = get_conj (hcc[i * this ->nbase_x + j]);
581
- scc[j * this ->nbase_x + i] = get_conj (scc[i * this ->nbase_x + j]);
582
- }
583
- }
584
- for (size_t i = nbase; i < this ->nbase_x ; i++)
585
- {
586
- for (size_t j = nbase; j < this ->nbase_x ; j++)
587
- {
588
- hcc[i * this ->nbase_x + j] = cs.zero ;
589
- scc[i * this ->nbase_x + j] = cs.zero ;
590
- hcc[j * this ->nbase_x + i] = cs.zero ;
591
- scc[j * this ->nbase_x + i] = cs.zero ;
592
- }
516
+ hcc[i * this ->nbase_x + j] = cs.zero ;
517
+ scc[i * this ->nbase_x + j] = cs.zero ;
518
+ hcc[j * this ->nbase_x + i] = cs.zero ;
519
+ scc[j * this ->nbase_x + i] = cs.zero ;
593
520
}
594
521
}
595
522
0 commit comments