Skip to content

Commit 49737e5

Browse files
authored
update cal_elem func (#3953)
avoiding matrix transposition can computational efficiency.
1 parent aacf171 commit 49737e5

File tree

2 files changed

+47
-121
lines changed

2 files changed

+47
-121
lines changed

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 46 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ void Diago_DavSubspace<T, Device>::diag_once(hamilt::Hamilt<T, Device>* phm_in,
134134
basis,
135135
this->hphi,
136136
this->hcc,
137-
this->scc,
138-
true);
137+
this->scc);
139138

140139
this->diag_zhegvx(nbase,
141140
this->n_band,
@@ -175,8 +174,7 @@ void Diago_DavSubspace<T, Device>::diag_once(hamilt::Hamilt<T, Device>* phm_in,
175174
basis,
176175
this->hphi,
177176
this->hcc,
178-
this->scc,
179-
false);
177+
this->scc);
180178

181179
this->diag_zhegvx(nbase,
182180
this->n_band,
@@ -399,84 +397,43 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
399397
const psi::Psi<T, Device>& basis,
400398
const T* hphi,
401399
T* hcc,
402-
T* scc,
403-
bool init)
400+
T* scc)
404401
{
405402
ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem");
406403

407-
if (init)
408-
{
409-
assert(nbase == 0);
410-
assert(this->n_band == notconv);
411-
gemm_op<T, Device>()(this->ctx,
412-
'C',
413-
'N',
414-
notconv,
415-
notconv,
416-
this->dim,
417-
this->one,
418-
&basis(0, 0),
419-
this->dim,
420-
hphi,
421-
this->dim,
422-
this->zero,
423-
hcc,
424-
this->nbase_x);
425-
426-
gemm_op<T, Device>()(this->ctx,
427-
'C',
428-
'N',
429-
notconv,
430-
notconv,
431-
this->dim,
432-
this->one,
433-
&basis(0, 0),
434-
this->dim,
435-
&basis(0, 0),
436-
this->dim,
437-
this->zero,
438-
scc,
439-
this->nbase_x);
440-
}
441-
else
442-
{
443-
gemm_op<T, Device>()(this->ctx,
444-
'C',
445-
'N',
446-
notconv,
447-
nbase + notconv,
448-
this->dim,
449-
this->one,
450-
&hphi[nbase * this->dim],
451-
this->dim,
452-
&basis(0, 0),
453-
this->dim,
454-
this->zero,
455-
hcc + nbase,
456-
this->nbase_x);
404+
gemm_op<T, Device>()(this->ctx,
405+
'C',
406+
'N',
407+
nbase + notconv,
408+
notconv,
409+
this->dim,
410+
this->one,
411+
&basis(0, 0),
412+
this->dim,
413+
&hphi[nbase * this->dim],
414+
this->dim,
415+
this->zero,
416+
&hcc[nbase * this->nbase_x],
417+
this->nbase_x);
457418

458-
gemm_op<T, Device>()(this->ctx,
459-
'C',
460-
'N',
461-
notconv,
462-
nbase + notconv,
463-
this->dim,
464-
this->one,
465-
&basis(nbase, 0),
466-
this->dim,
467-
&basis(0, 0),
468-
this->dim,
469-
this->zero,
470-
scc + nbase,
471-
this->nbase_x);
472-
}
419+
gemm_op<T, Device>()(this->ctx,
420+
'C',
421+
'N',
422+
nbase + notconv,
423+
notconv,
424+
this->dim,
425+
this->one,
426+
&basis(0, 0),
427+
this->dim,
428+
&basis(nbase, 0),
429+
this->dim,
430+
this->zero,
431+
&scc[nbase * this->nbase_x],
432+
this->nbase_x);
473433

474434
#ifdef __MPI
475435
if (GlobalV::NPROC_IN_POOL > 1)
476436
{
477-
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, hcc, hcc);
478-
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, scc, scc);
479-
480437
auto* swap = new T[notconv * this->nbase_x];
481438
syncmem_complex_op()(this->ctx, this->ctx, swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x);
482439

@@ -532,64 +489,34 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
532489
}
533490
}
534491
delete[] swap;
535-
536-
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, hcc, hcc);
537-
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, scc, scc);
538492
}
539493
#endif
540494

541-
nbase += notconv;
542-
int nb1 = nbase - notconv;
543-
// reset:
544-
if (init)
495+
const size_t last_nbase = nbase; // init: last_nbase = 0
496+
nbase = nbase + notconv;
497+
498+
for (size_t i = 0; i < nbase; i++)
545499
{
546-
for (size_t i = 0; i < nbase; i++)
500+
if (i >= last_nbase)
547501
{
548-
549502
hcc[i * this->nbase_x + i] = set_real_tocomplex(hcc[i * this->nbase_x + i]);
550503
scc[i * this->nbase_x + i] = set_real_tocomplex(scc[i * this->nbase_x + i]);
551-
552-
for (size_t j = i + 1; j < nbase; j++)
553-
{
554-
hcc[j * this->nbase_x + i] = get_conj(hcc[i * this->nbase_x + j]);
555-
scc[j * this->nbase_x + i] = get_conj(scc[i * this->nbase_x + j]);
556-
}
557504
}
558-
for (size_t i = nbase; i < this->nbase_x; i++)
505+
for (size_t j = std::max(i + 1, last_nbase); j < nbase; j++)
559506
{
560-
for (size_t j = nbase; j < this->nbase_x; j++)
561-
{
562-
hcc[i * this->nbase_x + j] = cs.zero;
563-
scc[i * this->nbase_x + j] = cs.zero;
564-
hcc[j * this->nbase_x + i] = cs.zero;
565-
scc[j * this->nbase_x + i] = cs.zero;
566-
}
507+
hcc[i * this->nbase_x + j] = get_conj(hcc[j * this->nbase_x + i]);
508+
scc[i * this->nbase_x + j] = get_conj(scc[j * this->nbase_x + i]);
567509
}
568510
}
569-
else
511+
512+
for (size_t i = nbase; i < this->nbase_x; i++)
570513
{
571-
for (size_t i = 0; i < nbase; i++)
514+
for (size_t j = nbase; j < this->nbase_x; j++)
572515
{
573-
if (i >= nb1)
574-
{
575-
hcc[i * this->nbase_x + i] = set_real_tocomplex(hcc[i * this->nbase_x + i]);
576-
scc[i * this->nbase_x + i] = set_real_tocomplex(scc[i * this->nbase_x + i]);
577-
}
578-
for (size_t j = std::max(i + 1, (size_t)nb1); j < nbase; j++)
579-
{
580-
hcc[j * this->nbase_x + i] = get_conj(hcc[i * this->nbase_x + j]);
581-
scc[j * this->nbase_x + i] = get_conj(scc[i * this->nbase_x + j]);
582-
}
583-
}
584-
for (size_t i = nbase; i < this->nbase_x; i++)
585-
{
586-
for (size_t j = nbase; j < this->nbase_x; j++)
587-
{
588-
hcc[i * this->nbase_x + j] = cs.zero;
589-
scc[i * this->nbase_x + j] = cs.zero;
590-
hcc[j * this->nbase_x + i] = cs.zero;
591-
scc[j * this->nbase_x + i] = cs.zero;
592-
}
516+
hcc[i * this->nbase_x + j] = cs.zero;
517+
scc[i * this->nbase_x + j] = cs.zero;
518+
hcc[j * this->nbase_x + i] = cs.zero;
519+
scc[j * this->nbase_x + i] = cs.zero;
593520
}
594521
}
595522

source/module_hsolver/diago_dav_subspace.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ class Diago_DavSubspace : public DiagH<T, Device>
8181
const psi::Psi<T, Device>& basis,
8282
const T* hphi,
8383
T* hcc,
84-
T* scc,
85-
bool init);
84+
T* scc);
8685

8786
void refresh(const int& dim,
8887
const int& nband,

0 commit comments

Comments
 (0)