Skip to content

Commit 21d40ce

Browse files
haozhihanmohanchen
andauthored
Feature: Add a new Davidson iteration method called subspace davidson for pw basis (#3903)
* add new_dav method which is same to davidson method for pw basis * add new_dav method for pw basis which is more efficient than origin dav method * fix ilaenv interface bug * update new_dav method 3.5 * implement new_dav method for pw basis (cpu version, one core) * fix bug of dav method for pw basis * debug for new davidson method * opt some value setting for new_dav files * format and reorganize the code * fix CUDA compile bug * format diago_newdav.cpp * Implement multi-core parallelism of the new davidson method * fix build bug for without mpi * replace new-dav of subspace-dav * change file name from diago_newdav to diago_subspacedav * fix build bug for tests * change the name of subspacedav to dav_subspace * fix build bug in Integration Test --------- Co-authored-by: Mohan Chen <[email protected]>
1 parent 125581b commit 21d40ce

File tree

19 files changed

+2269
-749
lines changed

19 files changed

+2269
-749
lines changed

source/Makefile.Objects

+1
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ OBJS_HCONTAINER=base_matrix.o\
286286

287287
OBJS_HSOLVER=diago_cg.o\
288288
diago_david.o\
289+
diago_dav_subspace.o\
289290
diagh_consts.o\
290291
diago_bpcg.o\
291292
hsolver_pw.o\

source/module_base/lapack_connector.h

+119-65
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
extern "C"
2323
{
2424

25+
int ilaenv_(int* ispec,const char* name,const char* opts,
26+
const int* n1,const int* n2,const int* n3,const int* n4);
27+
2528
// solve the generalized eigenproblem Ax=eBx, where A is Hermitian and complex couble
2629
// zhegv_ & zhegvd_ returns all eigenvalues while zhegvx_ returns selected ones
2730
void dsygvd_(const int* itype, const char* jobz, const char* uplo, const int* n,
@@ -60,9 +63,12 @@ extern "C"
6063
const int* m, double* w, std::complex<double> *z, const int *ldz,
6164
std::complex<double> *work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info);
6265

63-
void zhegv_(const int* itype,const char* jobz,const char* uplo,const int* n,
64-
std::complex<double>* a,const int* lda,std::complex<double>* b,const int* ldb,
65-
double* w,std::complex<double>* work,int* lwork,double* rwork,int* info);
66+
67+
void dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
68+
const int* n, double* A, const int* lda, double* B, const int* ldb,
69+
const double* vl, const double* vu, const int* il, const int* iu,
70+
const double* abstol, const int* m, double* w, double* Z, const int* ldz,
71+
double* work, int* lwork, int*iwork, int* ifail, int* info);
6672

6773
void chegvx_(const int* itype,const char* jobz,const char* range,const char* uplo,
6874
const int* n,std::complex<float> *a,const int* lda,std::complex<float> *b,
@@ -78,6 +84,16 @@ extern "C"
7884
std::complex<double> *z,const int *ldz,std::complex<double> *work,const int* lwork,
7985
double* rwork,int* iwork,int* ifail,int* info);
8086

87+
void zhegv_(const int* itype,const char* jobz,const char* uplo,const int* n,
88+
std::complex<double>* a,const int* lda,std::complex<double>* b,const int* ldb,
89+
double* w,std::complex<double>* work,int* lwork,double* rwork,int* info);
90+
void chegv_(const int* itype,const char* jobz,const char* uplo,const int* n,
91+
std::complex<float>* a,const int* lda,std::complex<float>* b,const int* ldb,
92+
float* w,std::complex<float>* work,int* lwork,float* rwork,int* info);
93+
void dsygv_(const int* itype, const char* jobz,const char* uplo, const int* n,
94+
double* a,const int* lda,double* b,const int* ldb,
95+
double* w,double* work,int* lwork,int* info);
96+
8197
// solve the eigenproblem Ax=ex, where A is Hermitian and complex couble
8298
// zheev_ returns all eigenvalues while zheevx_ returns selected ones
8399
void zheev_(const char* jobz,const char* uplo,const int* n,std::complex<double> *a,
@@ -86,18 +102,6 @@ extern "C"
86102
void cheev_(const char* jobz,const char* uplo,const int* n,std::complex<float> *a,
87103
const int* lda,float* w,std::complex<float >* work,const int* lwork,
88104
float* rwork,int* info);
89-
90-
// solve the generalized eigenproblem Ax=eBx, where A is Symmetric and real couble
91-
// dsygv_ returns all eigenvalues while dsygvx_ returns selected ones
92-
void dsygv_(const int* itype, const char* jobz,const char* uplo, const int* n,
93-
double* a,const int* lda,double* b,const int* ldb,
94-
double* w,double* work,int* lwork,int* info);
95-
void dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
96-
const int* n, double* A, const int* lda, double* B, const int* ldb,
97-
const double* vl, const double* vu, const int* il, const int* iu,
98-
const double* abstol, int* m, double* w, double* Z, const int* ldz,
99-
double* work, int* lwork, int*iwork, int* ifail, int* info);
100-
// solve the eigenproblem Ax=ex, where A is Symmetric and real double
101105
void dsyev_(const char* jobz,const char* uplo,const int* n,double *a,
102106
const int* lda,double* w,double* work,const int* lwork, int* info);
103107

@@ -314,23 +318,19 @@ class LapackConnector
314318
}
315319

316320
public:
317-
// wrap function of fortran lapack routine zhegvd.
318321
static inline
319-
void zhegvd(const int itype, const char jobz, const char uplo, const int n,
320-
std::complex<double>* a, const int lda,
321-
const std::complex<double>* b, const int ldb, double* w,
322-
std::complex<double>* work, int lwork, double* rwork, int lrwork,
323-
int* iwork, int liwork, int& info)
322+
int ilaenv( int ispec, const char *name,const char *opts,const int n1,const int n2,
323+
const int n3,const int n4)
324324
{
325-
zhegvd_(&itype, &jobz, &uplo, &n,
326-
a, &lda, b, &ldb, w,
327-
work, &lwork, rwork, &lrwork,
328-
iwork, &liwork, &info);
325+
const int nb = ilaenv_(&ispec, name, opts, &n1, &n2, &n3, &n4);
326+
return nb;
329327
}
330328

329+
330+
331331
// wrap function of fortran lapack routine zhegvd. (pointer version)
332332
static inline
333-
void xhegvd(const int itype, const char jobz, const char uplo, const int n,
333+
void xhegvd(const int itype, const char jobz, const char uplo, const int n,
334334
double* a, const int lda,
335335
const double* b, const int ldb, double* w,
336336
double* work, int lwork, double* rwork, int lrwork,
@@ -373,23 +373,9 @@ class LapackConnector
373373
iwork, &liwork, &info);
374374
}
375375

376-
// wrap function of fortran lapack routine zheevx.
377-
static inline
378-
void zheevx( const int itype, const char jobz, const char range, const char uplo, const int n,
379-
std::complex<double>* a, const int lda,
380-
const double vl, const double vu, const int il, const int iu, const double abstol,
381-
const int m, double* w, std::complex<double>* z, const int ldz,
382-
std::complex<double>* work, const int lwork, double* rwork, int* iwork, int* ifail, int& info)
383-
{
384-
zheevx_(&jobz, &range, &uplo, &n,
385-
a, &lda, &vl, &vu, &il, &iu,
386-
&abstol, &m, w, z, &ldz,
387-
work, &lwork, rwork, iwork, ifail, &info);
388-
}
389-
390376
// wrap function of fortran lapack routine dsyevx.
391377
static inline
392-
void xheevx(const int itype, const char jobz, const char range, const char uplo, const int n,
378+
void xheevx(const int itype, const char jobz, const char range, const char uplo, const int n,
393379
double* a, const int lda,
394380
const double vl, const double vu, const int il, const int iu, const double abstol,
395381
const int m, double* w, double* z, const int ldz,
@@ -428,6 +414,98 @@ class LapackConnector
428414
&abstol, &m, w, z, &ldz,
429415
work, &lwork, rwork, iwork, ifail, &info);
430416
}
417+
418+
// wrap function of fortran lapack routine xhegvx ( pointer version ).
419+
static inline
420+
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
421+
const int n, std::complex<float>* a, const int lda, std::complex<float>* b,
422+
const int ldb, const float vl, const float vu, const int il, const int iu,
423+
const float abstol, const int m, float* w, std::complex<float>* z, const int ldz,
424+
std::complex<float>* work, const int lwork, float* rwork, int* iwork,
425+
int* ifail, int& info)
426+
{
427+
chegvx_(&itype, &jobz, &range, &uplo, &n, a, &lda, b, &ldb, &vl,
428+
&vu, &il,&iu, &abstol, &m, w, z, &ldz, work, &lwork, rwork, iwork, ifail, &info);
429+
}
430+
431+
// wrap function of fortran lapack routine xhegvx ( pointer version ).
432+
static inline
433+
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
434+
const int n, std::complex<double>* a, const int lda, std::complex<double>* b,
435+
const int ldb, const double vl, const double vu, const int il, const int iu,
436+
const double abstol, const int m, double* w, std::complex<double>* z, const int ldz,
437+
std::complex<double>* work, const int lwork, double* rwork, int* iwork,
438+
int* ifail, int& info)
439+
{
440+
zhegvx_(&itype, &jobz, &range, &uplo, &n, a, &lda, b, &ldb, &vl,
441+
&vu, &il,&iu, &abstol, &m, w, z, &ldz, work, &lwork, rwork, iwork, ifail, &info);
442+
}
443+
// wrap function of fortran lapack routine xhegvx ( pointer version ).
444+
static inline
445+
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
446+
const int n, double* a, const int lda, double* b,
447+
const int ldb, const double vl, const double vu, const int il, const int iu,
448+
const double abstol, const int m, double* w, double* z, const int ldz,
449+
double* work, const int lwork, double* rwork, int* iwork,
450+
int* ifail, int& info)
451+
{
452+
// dsygvx_(&itype, &jobz, &range, &uplo, &n, a, &lda, b, &ldb, &vl,
453+
// &vu, &il,&iu, &abstol, &m, w, z, &ldz, work, &lwork, rwork, iwork, ifail, &info);
454+
}
455+
456+
457+
// wrap function of fortran lapack routine xhegvx ( pointer version ).
458+
static inline
459+
void xhegv( const int itype, const char jobz, const char uplo,
460+
const int n,
461+
double* a, const int lda,
462+
double* b, const int ldb,
463+
double* w,
464+
double* work, int lwork,
465+
double* rwork, int& info)
466+
{
467+
// TODO
468+
}
469+
470+
// wrap function of fortran lapack routine xhegvx ( pointer version ).
471+
static inline
472+
void xhegv( const int itype, const char jobz, const char uplo,
473+
const int n,
474+
std::complex<float>* a, const int lda,
475+
std::complex<float>* b, const int ldb,
476+
float* w,
477+
std::complex<float>* work, int lwork,
478+
float* rwork, int& info)
479+
{
480+
// TODO
481+
}
482+
// wrap function of fortran lapack routine xhegvx ( pointer version ).
483+
static inline
484+
void xhegv( const int itype, const char jobz, const char uplo,
485+
const int n,
486+
std::complex<double>* a, const int lda,
487+
std::complex<double>* b, const int ldb,
488+
double* w,
489+
std::complex<double>* work, int lwork,
490+
double* rwork, int& info)
491+
{
492+
zhegv_(&itype, &jobz, &uplo, &n, a, &lda, b, &ldb, w, work, &lwork, rwork, &info);
493+
}
494+
495+
496+
// wrap function of fortran lapack routine zhegvd.
497+
static inline
498+
void zhegvd(const int itype, const char jobz, const char uplo, const int n,
499+
std::complex<double>* a, const int lda,
500+
const std::complex<double>* b, const int ldb, double* w,
501+
std::complex<double>* work, int lwork, double* rwork, int lrwork,
502+
int* iwork, int liwork, int& info)
503+
{
504+
zhegvd_(&itype, &jobz, &uplo, &n,
505+
a, &lda, b, &ldb, w,
506+
work, &lwork, rwork, &lrwork,
507+
iwork, &liwork, &info);
508+
}
431509

432510
// wrap function of fortran lapack routine zhegv ( ModuleBase::ComplexMatrix version ).
433511
static inline
@@ -543,30 +621,6 @@ class LapackConnector
543621
delete[] zux;
544622
}
545623

546-
// wrap function of fortran lapack routine xhegvx ( pointer version ).
547-
static inline
548-
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
549-
const int n, const std::complex<float>* a, const int lda, const std::complex<float>* b,
550-
const int ldb, const float vl, const float vu, const int il, const int iu,
551-
const float abstol, const int m, float* w, std::complex<float>* z, const int ldz,
552-
std::complex<float>* work, const int lwork, float* rwork, int* iwork,
553-
int* ifail, int& info, int nbase_x)
554-
{
555-
chegvx(itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, work, lwork, rwork, iwork, ifail, info, nbase_x);
556-
}
557-
558-
// wrap function of fortran lapack routine xhegvx ( pointer version ).
559-
static inline
560-
void xhegvx( const int itype, const char jobz, const char range, const char uplo,
561-
const int n, const std::complex<double>* a, const int lda, const std::complex<double>* b,
562-
const int ldb, const double vl, const double vu, const int il, const int iu,
563-
const double abstol, const int m, double* w, std::complex<double>* z, const int ldz,
564-
std::complex<double>* work, const int lwork, double* rwork, int* iwork,
565-
int* ifail, int& info, int nbase_x)
566-
{
567-
zhegvx(itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, work, lwork, rwork, iwork, ifail, info, nbase_x);
568-
}
569-
570624
// calculate the eigenvalues and eigenfunctions of a real symmetric matrix.
571625
static inline
572626
void dsygv( const int itype,const char jobz,const char uplo,const int n,ModuleBase::matrix& a,

source/module_basis/module_ao/ORB_control.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ void ORB_control::setup_2d_division(std::ofstream& ofs_running,
203203

204204
// determine whether 2d-division or not according to ks_solver
205205
bool div_2d;
206-
if (ks_solver == "lapack" || ks_solver == "cg" || ks_solver == "dav") div_2d = false;
206+
if (ks_solver == "lapack" || ks_solver == "cg" || ks_solver == "dav" || ks_solver == "dav_subspace") div_2d = false;
207207
#ifdef __MPI
208208
else if (ks_solver == "genelpa" || ks_solver == "scalapack_gvx" || ks_solver == "cusolver" || ks_solver == "cg_in_lcao") div_2d = true;
209209
#endif

source/module_elecstate/elecstate_print.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,10 @@ void ElecState::print_etot(const bool converged,
289289
{
290290
label = "DA";
291291
}
292+
else if (ks_solver_type == "dav_subspace")
293+
{
294+
label = "DS";
295+
}
292296
else if (ks_solver_type == "scalapack_gvx")
293297
{
294298
label = "GV";

source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ void diago_PAO_in_pw_k2(const psi::DEVICE_GPU *ctx,
574574
//GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data());
575575
}
576576
}
577-
else if(GlobalV::KS_SOLVER=="dav")
577+
else if (GlobalV::KS_SOLVER == "dav" || GlobalV::KS_SOLVER == "dav_subspace")
578578
{
579579
assert(nbands <= wfcatom.nr);
580580
// replace by haozhihan 2022-11-23
@@ -685,8 +685,8 @@ void diago_PAO_in_pw_k2(const psi::DEVICE_GPU *ctx,
685685
//GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data());
686686
}
687687
}
688-
else if(GlobalV::KS_SOLVER=="dav")
689-
{
688+
else if (GlobalV::KS_SOLVER == "dav" || GlobalV::KS_SOLVER == "dav_subspace")
689+
{
690690
assert(nbands <= wfcatom.nr);
691691
// replace by haozhihan 2022-11-23
692692
hsolver::matrixSetToAnother<std::complex<double>, psi::DEVICE_GPU>()(

source/module_hsolver/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ list(APPEND objects
22
diagh_consts.cpp
33
diago_cg.cpp
44
diago_david.cpp
5+
diago_dav_subspace.cpp
56
diago_bpcg.cpp
67
hsolver_pw.cpp
78
hsolver_pw_sdft.cpp

0 commit comments

Comments
 (0)