Skip to content

Commit 924d29d

Browse files
authored
Reopen unit test for DeePKS multi-k case. (#6039)
* Add default threshold and precision value for Output_HContainer. * Reopen DeePKS multi-k unittest. * Update multi-k ut in DeePKS.
1 parent 2610f1e commit 924d29d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+6059
-40253
lines changed

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
173173
this->ld->lmaxd,
174174
this->ld->inl2l,
175175
this->ld->inl_index,
176+
this->kvec_d,
176177
this->DM,
177178
this->ld->phialpha,
178179
*this->ucell,

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
3939
using TH = std::conditional_t<std::is_same<TK, double>::value, ModuleBase::matrix, ModuleBase::ComplexMatrix>;
4040

4141
// These variables are frequently used in the following code
42-
const int inlmax = orb.Alpha[0].getTotal_nchi() * nat;
42+
const int nlmax = orb.Alpha[0].getTotal_nchi();
43+
const int inlmax = nlmax * nat;
4344
const int lmaxd = orb.get_lmax_d();
4445
const int nmaxd = ld->nmaxd;
4546

@@ -62,7 +63,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
6263
// this part is for integrated test of deepks
6364
// so it is printed no matter even if deepks_out_labels is not used
6465
DeePKS_domain::cal_pdm<
65-
TK>(init_pdm, inlmax, lmaxd, inl2l, inl_index, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);
66+
TK>(init_pdm, inlmax, lmaxd, inl2l, inl_index, kvec_d, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);
6667

6768
DeePKS_domain::check_pdm(inlmax, inl2l, pdm); // print out the projected dm for NSCF calculaiton
6869

@@ -312,6 +313,35 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
312313
out_hr.write();
313314
ofs_hr.close();
314315
}
316+
317+
const std::string file_vdrpre = PARAM.globalv.global_out_dir + "deepks_vdrpre.csr";
318+
std::vector<hamilt::HContainer<TR>*> h_deltaR_pre(inlmax);
319+
for (int i = 0; i < inlmax; i++)
320+
{
321+
h_deltaR_pre[i] = new hamilt::HContainer<TR>(*hR_tot);
322+
h_deltaR_pre[i]->set_zero();
323+
}
324+
// DeePKS_domain::cal_vdr_precalc<TR>();
325+
if (rank == 0)
326+
{
327+
std::ofstream ofs_hrp(file_vdrpre, std::ios::out);
328+
for (int iat = 0; iat < nat; iat++)
329+
{
330+
ofs_hrp << "- Index of atom: " << iat << std::endl;
331+
for (int nl = 0; nl < nlmax; nl++)
332+
{
333+
int inl = iat * nlmax + nl;
334+
ofs_hrp << "-- Index of nl: " << nl << std::endl;
335+
ofs_hrp << "Matrix Dimension of H_delta(R): " << h_deltaR_pre[inl]->get_nbasis() << std::endl;
336+
ofs_hrp << "Matrix number of H_delta(R): " << h_deltaR_pre[inl]->size_R_loop() << std::endl;
337+
hamilt::Output_HContainer<TR> out_hrp(h_deltaR_pre[inl], ofs_hrp, sparse_threshold, precision);
338+
out_hrp.write();
339+
ofs_hrp << std::endl;
340+
}
341+
ofs_hrp << std::endl;
342+
}
343+
ofs_hrp.close();
344+
}
315345
}
316346
}
317347

source/module_hamilt_lcao/module_deepks/deepks_pdm.cpp

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
9393
const int lmaxd,
9494
const std::vector<int>& inl2l,
9595
const ModuleBase::IntArray* inl_index,
96+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
9697
const elecstate::DensityMatrix<TK, double>* dm,
9798
const std::vector<hamilt::HContainer<double>*> phialpha,
9899
const UnitCell& ucell,
@@ -231,7 +232,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
231232
}
232233
}
233234

234-
for (int ad2 = 0; ad2 < adjs.adj_num + 1; ad2++)
235+
for (int ad2 = 0; ad2 < adjs.adj_num + 1; ad2++)
235236
{
236237
const int T2 = adjs.ntype[ad2];
237238
const int I2 = adjs.natom[ad2];
@@ -274,33 +275,31 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
274275
// prepare DM from DMR
275276
std::vector<double> dm_array(row_size * col_size, 0.0);
276277
const double* dm_current = nullptr;
277-
for (int is = 0; is < dm->get_DMR_vector().size(); is++)
278+
int dRx = 0, dRy = 0, dRz = 0;
279+
if constexpr (std::is_same<TK, std::complex<double>>::value)
278280
{
279-
int dRx = 0, dRy = 0, dRz = 0;
280-
if constexpr (std::is_same<TK, std::complex<double>>::value)
281-
{
282-
dRx = dR2.x - dR1.x;
283-
dRy = dR2.y - dR1.y;
284-
dRz = dR2.z - dR1.z;
285-
}
286-
// dm_R
287-
auto* tmp = dm->get_DMR_vector()[is]->find_matrix(ibt1, ibt2, dRx, dRy, dRz);
288-
if (tmp == nullptr)
289-
{
290-
// in case of no deepks_scf but out_deepks_label, size of DMR would mismatch with
291-
// deepks-orbitals
292-
dm_current = nullptr;
293-
break;
294-
}
295-
dm_current = tmp->get_pointer();
296-
for (int idm = 0; idm < row_size * col_size; idm++)
297-
{
298-
dm_array[idm] += dm_current[idm];
299-
}
281+
dRx = dR2.x - dR1.x;
282+
dRy = dR2.y - dR1.y;
283+
dRz = dR2.z - dR1.z;
300284
}
301-
if (dm_current == nullptr)
285+
// dm_k
286+
auto dm_k = dm->get_DMK_vector();
287+
const int nrow = pv.nrow;
288+
for (int ir = 0; ir < row_size; ir++)
302289
{
303-
continue; // skip the long range DM pair more than nonlocal term
290+
for (int ic = 0; ic < col_size; ic++)
291+
{
292+
int iglob = (pv.atom_begin_row[ibt1] + ir) + nrow * (pv.atom_begin_col[ibt2] + ic);
293+
int iloc = ir * col_size + ic;
294+
std::complex<double> tmp = 0.0;
295+
for(int ik = 0; ik < dm_k.size(); ik++) // dm_k.size() == _nk * _nspin
296+
{
297+
const double arg = (kvec_d[ik] * ModuleBase::Vector3<double>(dR1 - dR2)) * ModuleBase::TWO_PI;
298+
const std::complex<double> kphase = std::complex<double>(cos(arg), sin(arg));
299+
tmp += dm_k[ik][iglob] * kphase;
300+
}
301+
dm_array[iloc] += tmp.real();
302+
}
304303
}
305304

306305
dm_current = dm_array.data();
@@ -311,18 +310,18 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
311310
constexpr char transa = 'T', transb = 'N';
312311
const double gemm_alpha = 1.0, gemm_beta = 1.0;
313312
dgemm_(&transa,
314-
&transb,
315-
&row_size,
316-
&trace_alpha_size,
317-
&col_size,
318-
&gemm_alpha,
319-
dm_current,
320-
&col_size,
321-
s_2t.data(),
322-
&col_size,
323-
&gemm_beta,
324-
g_1dmt.data(),
325-
&row_size);
313+
&transb,
314+
&row_size,
315+
&trace_alpha_size,
316+
&col_size,
317+
&gemm_alpha,
318+
dm_current,
319+
&col_size,
320+
s_2t.data(),
321+
&col_size,
322+
&gemm_beta,
323+
g_1dmt.data(),
324+
&row_size);
326325
} // ad2
327326
if (!PARAM.inp.deepks_equiv)
328327
{
@@ -340,10 +339,10 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
340339
for (int m2 = 0; m2 < nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
341340
{
342341
accessor[m1][m2] += ddot_(&row_size,
343-
g_1dmt.data() + index * row_size,
344-
&inc,
345-
s_1t.data() + index * row_size,
346-
&inc);
342+
g_1dmt.data() + index * row_size,
343+
&inc,
344+
s_1t.data() + index * row_size,
345+
&inc);
347346
index++;
348347
}
349348
}
@@ -366,10 +365,10 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
366365
// ddot_: dot product of two vectors
367366
// inc means the increment of the index
368367
accessor[iproj * nproj + jproj] += ddot_(&row_size,
369-
g_1dmt.data() + index * row_size,
370-
&inc,
371-
s_1t.data() + index * row_size,
372-
&inc);
368+
g_1dmt.data() + index * row_size,
369+
&inc,
370+
s_1t.data() + index * row_size,
371+
&inc);
373372
index++;
374373
}
375374
}
@@ -414,6 +413,7 @@ template void DeePKS_domain::cal_pdm<double>(bool& init_pdm,
414413
const int lmaxd,
415414
const std::vector<int>& inl2l,
416415
const ModuleBase::IntArray* inl_index,
416+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
417417
const elecstate::DensityMatrix<double, double>* dm,
418418
const std::vector<hamilt::HContainer<double>*> phialpha,
419419
const UnitCell& ucell,
@@ -428,6 +428,7 @@ template void DeePKS_domain::cal_pdm<std::complex<double>>(
428428
const int lmaxd,
429429
const std::vector<int>& inl2l,
430430
const ModuleBase::IntArray* inl_index,
431+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
431432
const elecstate::DensityMatrix<std::complex<double>, double>* dm,
432433
const std::vector<hamilt::HContainer<double>*> phialpha,
433434
const UnitCell& ucell,

source/module_hamilt_lcao/module_deepks/deepks_pdm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ void cal_pdm(bool& init_pdm,
5050
const int lmaxd,
5151
const std::vector<int>& inl2l,
5252
const ModuleBase::IntArray* inl_index,
53+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
5354
const elecstate::DensityMatrix<TK, double>* dm,
5455
const std::vector<hamilt::HContainer<double>*> phialpha,
5556
const UnitCell& ucell,

source/module_hamilt_lcao/module_deepks/deepks_vdpre.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
3838
const Grid_Driver& GridD,
3939
torch::Tensor& v_delta_precalc)
4040
{
41-
ModuleBase::TITLE("DeePKS_domain", "calc_v_delta_precalc");
42-
ModuleBase::timer::tick("DeePKS_domain", "calc_v_delta_precalc");
41+
ModuleBase::TITLE("DeePKS_domain", "cal_v_delta_precalc");
42+
ModuleBase::timer::tick("DeePKS_domain", "cal_v_delta_precalc");
4343
// timeval t_start;
4444
// gettimeofday(&t_start,NULL);
4545

@@ -230,7 +230,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
230230
// std::cout<<"calculate v_delta_precalc time:\t"<<(double)(t_end.tv_sec-t_start.tv_sec) +
231231
// (double)(t_end.tv_usec-t_start.tv_usec)/1000000.0<<std::endl;
232232

233-
ModuleBase::timer::tick("DeePKS_domain", "calc_v_delta_precalc");
233+
ModuleBase::timer::tick("DeePKS_domain", "cal_v_delta_precalc");
234234
return;
235235
}
236236

source/module_hamilt_lcao/module_deepks/test/LCAO_deepks_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ void test_deepks<T>::check_pdm()
142142
this->ld.lmaxd,
143143
this->ld.inl2l,
144144
this->ld.inl_index,
145+
kv.kvec_d,
145146
p_elec_DM,
146147
this->ld.phialpha,
147148
ucell,

source/module_hamilt_lcao/module_hcontainer/output_hcontainer.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ Output_HContainer<T>::Output_HContainer(hamilt::HContainer<T>* hcontainer,
1818
int precision)
1919
: _hcontainer(hcontainer), _ofs(ofs), _sparse_threshold(sparse_threshold), _precision(precision)
2020
{
21+
if (this->_sparse_threshold == -1)
22+
{
23+
this->_sparse_threshold = 1e-10;
24+
}
25+
if (this->_precision == -1)
26+
{
27+
this->_precision = 8;
28+
}
2129
}
2230

2331
template <typename T>

source/module_hamilt_lcao/module_hcontainer/output_hcontainer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ template <typename T>
1313
class Output_HContainer
1414
{
1515
public:
16-
Output_HContainer(hamilt::HContainer<T>* hcontainer, std::ostream& ofs, double sparse_threshold, int precision);
16+
Output_HContainer(hamilt::HContainer<T>* hcontainer, std::ostream& ofs, double sparse_threshold = -1, int precision = -1);
1717
// write the matrices of all R vectors to the output stream
1818
void write();
1919

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
etotref -465.9986234579913
2-
etotperatomref -155.3328744860
3-
totalforceref 5.535112
4-
totalstressref 1.522354
5-
totaldes 2.163682
6-
deepks_e_dm -57.88576364957592
7-
deepks_f_label 19.095631983991726
8-
deepks_s_label 19.250613228828858
9-
totaltimeref 22.06
1+
etotref -465.9986233931722
2+
etotperatomref -155.3328744644
3+
totalforceref 5.535484
4+
totalstressref 1.522431
5+
totaldes 2.163702
6+
deepks_e_dm -57.88572052925276
7+
deepks_f_label 19.09562238352583
8+
deepks_s_label 19.250613977989474
9+
totaltimeref 14.90
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
etotref -466.8189388994859
2-
etotperatomref -155.6063129665
1+
etotref -466.8189389058687
2+
etotperatomref -155.6063129686
33
totaldes 4.392987
4-
deepks_e_dm -49.145154045309184
5-
odelta 0.05196672366779986
6-
oprec 0.3729036159496012
7-
totaltimeref 11.14
4+
deepks_e_dm -49.145154051850646
5+
odelta 0.05196672144400882
6+
oprec 0.3729036136469043
7+
totaltimeref 11.03
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
etotref -466.8999964506085
2-
etotperatomref -155.6333321502
3-
totalforceref 10.047085
4-
totaltimeref 21.21
1+
etotref -466.8143178204656
2+
etotperatomref -155.6047726068
3+
totalforceref 10.244182
4+
totaltimeref 7.52
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.00123076105
1+
0.009421273749
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
-0.1964491488
1+
-0.1883243273
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
0.01002305073 0.001383599869 -0.008194913779
2-
-0.006744232782 -0.004657772796 0.0005568916214
3-
-0.00327881795 0.003274172927 0.007638022158
1+
0.002181260802 2.610262729e-05 -0.002150222026
2+
-0.001268781011 -0.0007629016191 0.0002663644433
3+
-0.000912479791 0.0007367989918 0.001883857582

tests/deepks/604_NO_deepks_ut_H2O_multik/STRU

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ LATTICE_CONSTANT
66
1
77

88
LATTICE_VECTORS
9-
10 0 0
10-
0 10 0
11-
0 0 10
9+
15 0 0
10+
0 15 0
11+
0 0 15
1212

1313
ATOMIC_POSITIONS
1414
Direct # Cartesian(Unit is LATTICE_CONSTANT)

0 commit comments

Comments
 (0)