Skip to content

Commit ef61f00

Browse files
authored
Refactor: set dm_to_rho as a new esolver (#6059)
1 parent e119b1f commit ef61f00

11 files changed

+158
-117
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
266266
lcao_after_scf.o\
267267
esolver_gets.o\
268268
lcao_others.o\
269+
esolver_dm2rho.o\
269270

270271
OBJS_GINT=gint.o\
271272
gint_gamma_env.o\

source/module_elecstate/elecstate_lcao.cpp

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,32 +23,6 @@ void ElecStateLCAO<std::complex<double>>::psiToRho(const psi::Psi<std::complex<d
2323
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
2424
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
2525

26-
// // the calculations of dm, and dm -> rho are, technically, two separate
27-
// // functionalities, as we cannot rule out the possibility that we may have a
28-
// // dm from other sources, such as read from file. However, since we are not
29-
// // separating them now, I opt to add a flag to control how dm is obtained as
30-
// // of now
31-
// if (!PARAM.inp.dm_to_rho)
32-
// {
33-
// ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
34-
35-
// // this part for calculating DMK in 2d-block format, not used for charge
36-
// // now
37-
// // psi::Psi<std::complex<double>> dm_k_2d();
38-
39-
// if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver ==
40-
// "scalapack_gvx" || PARAM.inp.ks_solver == "lapack"
41-
// || PARAM.inp.ks_solver == "cusolver" || PARAM.inp.ks_solver == "cusolvermp"
42-
// || PARAM.inp.ks_solver == "cg_in_lcao") // Peize Lin test 2019-05-15
43-
// {
44-
// elecstate::cal_dm_psi(this->DM->get_paraV_pointer(),
45-
// this->wg,
46-
// psi,
47-
// *(this->DM));
48-
// this->DM->cal_DMR();
49-
// }
50-
// }
51-
5226
for (int is = 0; is < PARAM.inp.nspin; is++)
5327
{
5428
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is],

source/module_esolver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ if(ENABLE_LCAO)
2020
lcao_after_scf.cpp
2121
esolver_gets.cpp
2222
lcao_others.cpp
23+
esolver_dm2rho.cpp
2324
)
2425
endif()
2526

source/module_esolver/esolver.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "module_base/module_device/device.h"
66
#include "module_parameter/parameter.h"
77
#ifdef __LCAO
8+
#include "esolver_dm2rho.h"
89
#include "esolver_gets.h"
910
#include "esolver_ks_lcao.h"
1011
#include "esolver_ks_lcao_tddft.h"
@@ -204,11 +205,25 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
204205
}
205206
else if (PARAM.inp.nspin < 4)
206207
{
207-
return new ESolver_KS_LCAO<std::complex<double>, double>();
208+
if (PARAM.inp.dm_to_rho)
209+
{
210+
return new ESolver_DM2rho<std::complex<double>, double>();
211+
}
212+
else
213+
{
214+
return new ESolver_KS_LCAO<std::complex<double>, double>();
215+
}
208216
}
209217
else
210218
{
211-
return new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
219+
if (PARAM.inp.dm_to_rho)
220+
{
221+
return new ESolver_DM2rho<std::complex<double>, std::complex<double>>();
222+
}
223+
else
224+
{
225+
return new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
226+
}
212227
}
213228
}
214229
else if (esolver_type == "ksdft_lcao_tddft")
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#include "esolver_dm2rho.h"
2+
3+
#include "module_base/timer.h"
4+
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
5+
#include "module_elecstate/elecstate_lcao.h"
6+
#include "module_elecstate/read_pseudo.h"
7+
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h"
8+
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
9+
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.h"
10+
#include "module_io/cube_io.h"
11+
#include "module_io/io_npz.h"
12+
#include "module_io/print_info.h"
13+
14+
namespace ModuleESolver
15+
{
16+
17+
template <typename TK, typename TR>
18+
ESolver_DM2rho<TK, TR>::ESolver_DM2rho()
19+
{
20+
this->classname = "ESolver_DM2rho";
21+
this->basisname = "LCAO";
22+
}
23+
24+
template <typename TK, typename TR>
25+
ESolver_DM2rho<TK, TR>::~ESolver_DM2rho()
26+
{
27+
}
28+
29+
template <typename TK, typename TR>
30+
void ESolver_DM2rho<TK, TR>::before_all_runners(UnitCell& ucell, const Input_para& inp)
31+
{
32+
ModuleBase::TITLE("ESolver_DM2rho", "before_all_runners");
33+
ModuleBase::timer::tick("ESolver_DM2rho", "before_all_runners");
34+
35+
ESolver_KS_LCAO<TK, TR>::before_all_runners(ucell, inp);
36+
37+
ModuleBase::timer::tick("ESolver_DM2rho", "before_all_runners");
38+
}
39+
40+
template <typename TK, typename TR>
41+
void ESolver_DM2rho<TK, TR>::runner(UnitCell& ucell, const int istep)
42+
{
43+
ModuleBase::TITLE("ESolver_DM2rho", "runner");
44+
ModuleBase::timer::tick("ESolver_DM2rho", "runner");
45+
46+
ESolver_KS_LCAO<TK, TR>::before_scf(ucell, istep);
47+
48+
// file name of DM
49+
std::string zipname = "output_DM0.npz";
50+
elecstate::DensityMatrix<TK, double>* dm = dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM();
51+
52+
// read DM from file
53+
ModuleIO::read_mat_npz(&(this->pv), ucell, zipname, *(dm->get_DMR_pointer(1)));
54+
55+
// if nspin=2, need extra reading
56+
if (PARAM.inp.nspin == 2)
57+
{
58+
zipname = "output_DM1.npz";
59+
ModuleIO::read_mat_npz(&(this->pv), ucell, zipname, *(dm->get_DMR_pointer(2)));
60+
}
61+
62+
this->pelec->psiToRho(*this->psi);
63+
64+
int nspin0 = PARAM.inp.nspin == 2 ? 2 : 1;
65+
66+
for (int is = 0; is < nspin0; is++)
67+
{
68+
std::string fn = PARAM.globalv.global_out_dir + "/SPIN" + std::to_string(is + 1) + "_CHG.cube";
69+
70+
// write electron density
71+
ModuleIO::write_vdata_palgrid(this->Pgrid,
72+
this->chr.rho[is],
73+
is,
74+
PARAM.inp.nspin,
75+
istep,
76+
fn,
77+
this->pelec->eferm.get_efval(is),
78+
&(ucell),
79+
3,
80+
1);
81+
}
82+
83+
ModuleBase::timer::tick("ESolver_DM2rho", "runner");
84+
}
85+
86+
template <typename TK, typename TR>
87+
void ESolver_DM2rho<TK, TR>::after_all_runners(UnitCell& ucell)
88+
{
89+
ModuleBase::TITLE("ESolver_DM2rho", "after_all_runners");
90+
ModuleBase::timer::tick("ESolver_DM2rho", "after_all_runners");
91+
92+
ESolver_KS_LCAO<TK, TR>::after_all_runners(ucell);
93+
94+
ModuleBase::timer::tick("ESolver_DM2rho", "after_all_runners");
95+
};
96+
97+
template class ESolver_DM2rho<std::complex<double>, double>;
98+
template class ESolver_DM2rho<std::complex<double>, std::complex<double>>;
99+
100+
} // namespace ModuleESolver
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef ESOLVER_DM2RHO_H
2+
#define ESOLVER_DM2RHO_H
3+
4+
#include "module_esolver/esolver_ks_lcao.h"
5+
6+
#include <memory>
7+
8+
namespace ModuleESolver
9+
{
10+
11+
template <typename TK, typename TR>
12+
class ESolver_DM2rho : public ESolver_KS_LCAO<TK, TR>
13+
{
14+
public:
15+
ESolver_DM2rho();
16+
~ESolver_DM2rho();
17+
18+
void before_all_runners(UnitCell& ucell, const Input_para& inp) override;
19+
20+
void after_all_runners(UnitCell& ucell) override;
21+
22+
void runner(UnitCell& ucell, const int istep) override;
23+
};
24+
} // namespace ModuleESolver
25+
#endif

source/module_esolver/esolver_fp.cpp

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -152,20 +152,10 @@ void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_eso
152152
{
153153
for (int is = 0; is < PARAM.inp.nspin; is++)
154154
{
155-
double* data = nullptr;
156-
if (PARAM.inp.dm_to_rho)
157-
{
158-
data = this->chr.rho[is];
159-
this->pw_rhod->real2recip(this->chr.rho[is], this->chr.rhog[is]);
160-
}
161-
else
162-
{
163-
data = this->chr.rho_save[is];
164-
this->pw_rhod->real2recip(this->chr.rho_save[is], this->chr.rhog_save[is]);
165-
}
155+
this->pw_rhod->real2recip(this->chr.rho_save[is], this->chr.rhog_save[is]);
166156
std::string fn =PARAM.globalv.global_out_dir + "/SPIN" + std::to_string(is + 1) + "_CHG.cube";
167157
ModuleIO::write_vdata_palgrid(Pgrid,
168-
data,
158+
this->chr.rho_save[is],
169159
is,
170160
PARAM.inp.nspin,
171161
istep,
@@ -381,19 +371,16 @@ void ESolver_FP::iter_finish(UnitCell& ucell, const int istep, int& iter, bool&
381371
{
382372
if (iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver)
383373
{
384-
std::complex<double>** rhog_tot
385-
= (PARAM.inp.dm_to_rho) ? this->chr.rhog : this->chr.rhog_save;
386-
double** rhor_tot = (PARAM.inp.dm_to_rho) ? this->chr.rho : this->chr.rho_save;
387374
for (int is = 0; is < PARAM.inp.nspin; is++)
388375
{
389-
this->pw_rhod->real2recip(rhor_tot[is], rhog_tot[is]);
376+
this->pw_rhod->real2recip(this->chr.rho_save[is], this->chr.rhog_save[is]);
390377
}
391378
ModuleIO::write_rhog(PARAM.globalv.global_out_dir + PARAM.inp.suffix + "-CHARGE-DENSITY.restart",
392379
PARAM.globalv.gamma_only_pw || PARAM.globalv.gamma_only_local,
393380
this->pw_rhod,
394381
PARAM.inp.nspin,
395382
ucell.GT,
396-
rhog_tot,
383+
this->chr.rhog_save,
397384
GlobalV::MY_POOL,
398385
GlobalV::RANK_IN_POOL,
399386
GlobalV::NPROC_IN_POOL);

source/module_esolver/esolver_ks.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -439,13 +439,6 @@ void ESolver_KS<T, Device>::runner(UnitCell& ucell, const int istep)
439439
// 2) before_scf (electronic iteration loops)
440440
this->before_scf(ucell, istep);
441441

442-
// 3) write charge density
443-
if (PARAM.inp.dm_to_rho)
444-
{
445-
ModuleBase::timer::tick(this->classname, "runner");
446-
return; // nothing further is needed
447-
}
448-
449442
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT SCF");
450443

451444
// 4) SCF iterations

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -637,13 +637,11 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(UnitCell& ucell, const int istep, const
637637
this->pelec->nelec_spin,
638638
this->pelec->skip_weights);
639639

640-
if (!PARAM.inp.dm_to_rho)
641-
{
642-
auto _pelec = dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec);
643-
elecstate::calEBand(_pelec->ekb,_pelec->wg,_pelec->f_en);
644-
elecstate::cal_dm_psi(_pelec->DM->get_paraV_pointer(), _pelec->wg, *this->psi, *(_pelec->DM));
645-
_pelec->DM->cal_DMR();
646-
}
640+
auto _pelec = dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec);
641+
elecstate::calEBand(_pelec->ekb, _pelec->wg, _pelec->f_en);
642+
elecstate::cal_dm_psi(_pelec->DM->get_paraV_pointer(), _pelec->wg, *this->psi, *(_pelec->DM));
643+
_pelec->DM->cal_DMR();
644+
647645
this->pelec->psiToRho(*this->psi);
648646
this->pelec->skip_weights = false;
649647

source/module_esolver/lcao_before_scf.cpp

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -251,57 +251,6 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
251251
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->cal_DMR();
252252
}
253253

254-
if (PARAM.inp.dm_to_rho)
255-
{
256-
// file name of DM
257-
std::string zipname = "output_DM0.npz";
258-
elecstate::DensityMatrix<TK, double>* dm
259-
= dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM();
260-
261-
// read DM from file
262-
ModuleIO::read_mat_npz(&(this->pv), ucell, zipname, *(dm->get_DMR_pointer(1)));
263-
264-
// if nspin=2, need extra reading
265-
if (PARAM.inp.nspin == 2)
266-
{
267-
zipname = "output_DM1.npz";
268-
ModuleIO::read_mat_npz(&(this->pv), ucell, zipname, *(dm->get_DMR_pointer(2)));
269-
}
270-
271-
elecstate::calculate_weights(this->pelec->ekb,
272-
this->pelec->wg,
273-
this->pelec->klist,
274-
this->pelec->eferm,
275-
this->pelec->f_en,
276-
this->pelec->nelec_spin,
277-
this->pelec->skip_weights);
278-
279-
this->pelec->psiToRho(*this->psi);
280-
281-
int nspin0 = PARAM.inp.nspin == 2 ? 2 : 1;
282-
283-
for (int is = 0; is < nspin0; is++)
284-
{
285-
std::string fn = PARAM.globalv.global_out_dir + "/SPIN" + std::to_string(is + 1) + "_CHG.cube";
286-
287-
// write electron density
288-
ModuleIO::write_vdata_palgrid(this->Pgrid,
289-
this->chr.rho[is],
290-
is,
291-
PARAM.inp.nspin,
292-
istep,
293-
fn,
294-
this->pelec->eferm.get_efval(is),
295-
&(ucell),
296-
3,
297-
1);
298-
}
299-
300-
// why we need to return here? mohan add 2025-03-10
301-
ModuleBase::timer::tick("ESolver_KS_LCAO", "before_scf");
302-
return;
303-
}
304-
305254
// 16) the electron charge density should be symmetrized,
306255
// here is the initialization
307256
Symmetry_rho srho;

source/module_hsolver/hsolver_lcao.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,11 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
8383
pes->f_en,
8484
pes->nelec_spin,
8585
pes->skip_weights);
86-
if (!PARAM.inp.dm_to_rho)
87-
{
88-
auto _pes_lcao = dynamic_cast<elecstate::ElecStateLCAO<T>*>(pes);
89-
elecstate::calEBand(_pes_lcao->ekb,_pes_lcao->wg,_pes_lcao->f_en);
90-
elecstate::cal_dm_psi(_pes_lcao->DM->get_paraV_pointer(), _pes_lcao->wg, psi, *(_pes_lcao->DM));
91-
_pes_lcao->DM->cal_DMR();
92-
}
86+
87+
auto _pes_lcao = dynamic_cast<elecstate::ElecStateLCAO<T>*>(pes);
88+
elecstate::calEBand(_pes_lcao->ekb, _pes_lcao->wg, _pes_lcao->f_en);
89+
elecstate::cal_dm_psi(_pes_lcao->DM->get_paraV_pointer(), _pes_lcao->wg, psi, *(_pes_lcao->DM));
90+
_pes_lcao->DM->cal_DMR();
9391

9492
if (!skip_charge)
9593
{

0 commit comments

Comments
 (0)