Skip to content

Commit 391b41b

Browse files
authored
Refactor: Remove some files in DeePKS and add some features. (#6065)
* Add HR precalc functions for DeePKS and fix some bugs. * Remove and combine some files in DeePKS. * Fix wrong function position.
1 parent 4adc91b commit 391b41b

16 files changed

+274
-323
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
208208
deepks_orbpre.o\
209209
deepks_vdelta.o\
210210
deepks_vdpre.o\
211-
deepks_hmat.o\
211+
deepks_vdrpre.o\
212212
deepks_pdm.o\
213213
deepks_phialpha.o\
214214
LCAO_deepks_io.o\

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ if(ENABLE_DEEPKS)
1111
deepks_orbpre.cpp
1212
deepks_vdelta.cpp
1313
deepks_vdpre.cpp
14-
deepks_hmat.cpp
14+
deepks_vdrpre.cpp
1515
deepks_pdm.cpp
1616
deepks_phialpha.cpp
1717
LCAO_deepks_io.cpp

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
#include "deepks_descriptor.h"
88
#include "deepks_force.h"
99
#include "deepks_fpre.h"
10-
#include "deepks_hmat.h"
1110
#include "deepks_orbital.h"
1211
#include "deepks_orbpre.h"
1312
#include "deepks_pdm.h"
1413
#include "deepks_phialpha.h"
1514
#include "deepks_spre.h"
1615
#include "deepks_vdelta.h"
1716
#include "deepks_vdpre.h"
17+
#include "deepks_vdrpre.h"
1818
#include "module_base/complexmatrix.h"
1919
#include "module_base/intarray.h"
2020
#include "module_base/matrix.h"
@@ -56,10 +56,10 @@ class LCAO_Deepks
5656
public:
5757
///(Unit: Ry) Correction energy provided by NN
5858
double E_delta = 0.0;
59-
///(Unit: Ry) \f$tr(\rho H_\delta), \rho = \sum_i{c_{i, \mu}c_{i,\nu}} \f$ (for gamma_only)
59+
///(Unit: Ry) \f$tr(\rho H_\delta), \rho = \sum_i{c_{i, \mu}c_{i,\nu}} \f$
6060
double e_delta_band = 0.0;
6161

62-
/// Correction term to the Hamiltonian matrix: \f$\langle\phi|V_\delta|\phi\rangle\f$ (for gamma only)
62+
/// Correction term to the Hamiltonian matrix: \f$\langle\phi|V_\delta|\phi\rangle\f$
6363
/// The first dimension is for k-points V_delta(k)
6464
std::vector<std::vector<T>> V_delta;
6565

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -314,34 +314,18 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
314314
ofs_hr.close();
315315
}
316316

317-
const std::string file_vdrpre = PARAM.globalv.global_out_dir + "deepks_vdrpre.csr";
318-
std::vector<hamilt::HContainer<TR>*> h_deltaR_pre(inlmax);
319-
for (int i = 0; i < inlmax; i++)
320-
{
321-
h_deltaR_pre[i] = new hamilt::HContainer<TR>(*hR_tot);
322-
h_deltaR_pre[i]->set_zero();
323-
}
324-
// DeePKS_domain::cal_vdr_precalc<TR>();
325-
if (rank == 0)
326-
{
327-
std::ofstream ofs_hrp(file_vdrpre, std::ios::out);
328-
for (int iat = 0; iat < nat; iat++)
329-
{
330-
ofs_hrp << "- Index of atom: " << iat << std::endl;
331-
for (int nl = 0; nl < nlmax; nl++)
332-
{
333-
int inl = iat * nlmax + nl;
334-
ofs_hrp << "-- Index of nl: " << nl << std::endl;
335-
ofs_hrp << "Matrix Dimension of H_delta(R): " << h_deltaR_pre[inl]->get_nbasis() << std::endl;
336-
ofs_hrp << "Matrix number of H_delta(R): " << h_deltaR_pre[inl]->size_R_loop() << std::endl;
337-
hamilt::Output_HContainer<TR> out_hrp(h_deltaR_pre[inl], ofs_hrp, sparse_threshold, precision);
338-
out_hrp.write();
339-
ofs_hrp << std::endl;
340-
}
341-
ofs_hrp << std::endl;
342-
}
343-
ofs_hrp.close();
344-
}
317+
torch::Tensor phialpha_r_out;
318+
torch::Tensor R_query;
319+
DeePKS_domain::prepare_phialpha_r(nlocal, lmaxd, inlmax, nat, phialpha, ucell, orb, *ParaV, GridD, phialpha_r_out, R_query);
320+
const std::string file_phialpha_r = PARAM.globalv.global_out_dir + "deepks_phialpha_r.npy";
321+
const std::string file_R_query = PARAM.globalv.global_out_dir + "deepks_R_query.npy";
322+
LCAO_deepks_io::save_tensor2npy<double>(file_phialpha_r, phialpha_r_out, rank);
323+
LCAO_deepks_io::save_tensor2npy<int>(file_R_query, R_query, rank);
324+
325+
torch::Tensor gevdm_out;
326+
DeePKS_domain::prepare_gevdm(nat, lmaxd, inlmax, orb, gevdm, gevdm_out);
327+
const std::string file_gevdm = PARAM.globalv.global_out_dir + "deepks_gevdm.npy";
328+
LCAO_deepks_io::save_tensor2npy<double>(file_gevdm, gevdm_out, rank);
345329
}
346330
}
347331

source/module_hamilt_lcao/module_deepks/LCAO_deepks_io.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,18 +275,18 @@ void LCAO_deepks_io::save_tensor2npy(const std::string& file_name, const torch::
275275

276276
std::vector<T> data(tensor.numel());
277277

278-
if constexpr (std::is_same<T, double>::value)
279-
{
280-
std::memcpy(data.data(), tensor.data_ptr<double>(), tensor.numel() * sizeof(double));
281-
}
282-
else
278+
if constexpr (std::is_same<T, std::complex<double>>::value)
283279
{
284280
auto tensor_data = tensor.data_ptr<c10::complex<double>>();
285281
for (size_t i = 0; i < tensor.numel(); ++i)
286282
{
287283
data[i] = std::complex<double>(tensor_data[i].real(), tensor_data[i].imag());
288284
}
289285
}
286+
else
287+
{
288+
std::memcpy(data.data(), tensor.data_ptr<T>(), tensor.numel() * sizeof(T));
289+
}
290290

291291
npy::SaveArrayAsNumpy(file_name, false, shape.size(), shape.data(), data);
292292
}
@@ -313,6 +313,10 @@ template void LCAO_deepks_io::save_npy_h<std::complex<double>>(const std::vector
313313
const int nks,
314314
const int rank);
315315

316+
template void LCAO_deepks_io::save_tensor2npy<int>(const std::string& file_name,
317+
const torch::Tensor& tensor,
318+
const int rank);
319+
316320
template void LCAO_deepks_io::save_tensor2npy<double>(const std::string& file_name,
317321
const torch::Tensor& tensor,
318322
const int rank);

source/module_hamilt_lcao/module_deepks/deepks_hmat.cpp

Lines changed: 0 additions & 117 deletions
This file was deleted.

source/module_hamilt_lcao/module_deepks/deepks_hmat.h

Lines changed: 0 additions & 31 deletions
This file was deleted.

source/module_hamilt_lcao/module_deepks/deepks_vdelta.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,73 @@ void DeePKS_domain::cal_e_delta_band(const std::vector<std::vector<TK>>& dm,
7777
return;
7878
}
7979

80+
template <typename TK, typename TH>
81+
void DeePKS_domain::collect_h_mat(const Parallel_Orbitals& pv,
82+
const std::vector<std::vector<TK>>& h_in,
83+
std::vector<TH>& h_out,
84+
const int nlocal,
85+
const int nks)
86+
{
87+
ModuleBase::TITLE("DeePKS_domain", "collect_h_tot");
88+
89+
// construct the total H matrix
90+
for (int k = 0; k < nks; k++)
91+
{
92+
#ifdef __MPI
93+
int ir = 0;
94+
int ic = 0;
95+
for (int i = 0; i < nlocal; i++)
96+
{
97+
std::vector<TK> lineH(nlocal - i, TK(0.0));
98+
99+
ir = pv.global2local_row(i);
100+
if (ir >= 0)
101+
{
102+
// data collection
103+
for (int j = i; j < nlocal; j++)
104+
{
105+
ic = pv.global2local_col(j);
106+
if (ic >= 0)
107+
{
108+
int iic = 0;
109+
if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver))
110+
{
111+
iic = ir + ic * pv.nrow;
112+
}
113+
else
114+
{
115+
iic = ir * pv.ncol + ic;
116+
}
117+
lineH[j - i] = h_in[k][iic];
118+
}
119+
}
120+
}
121+
else
122+
{
123+
// do nothing
124+
}
125+
126+
Parallel_Reduce::reduce_all(lineH.data(), nlocal - i);
127+
128+
for (int j = i; j < nlocal; j++)
129+
{
130+
h_out[k](i, j) = lineH[j - i];
131+
h_out[k](j, i) = h_out[k](i, j); // H is a symmetric matrix
132+
}
133+
}
134+
#else
135+
for (int i = 0; i < nlocal; i++)
136+
{
137+
for (int j = i; j < nlocal; j++)
138+
{
139+
h_out[k](i, j) = h_in[k][i * nlocal + j];
140+
h_out[k](j, i) = h_out[k](i, j); // H is a symmetric matrix
141+
}
142+
}
143+
#endif
144+
}
145+
}
146+
80147
template void DeePKS_domain::cal_e_delta_band<double>(const std::vector<std::vector<double>>& dm,
81148
const std::vector<std::vector<double>>& V_delta,
82149
const int nks,
@@ -89,4 +156,18 @@ template void DeePKS_domain::cal_e_delta_band<std::complex<double>>(
89156
const Parallel_Orbitals* pv,
90157
double& e_delta_band);
91158

159+
template void DeePKS_domain::collect_h_mat<double, ModuleBase::matrix>(
160+
const Parallel_Orbitals& pv,
161+
const std::vector<std::vector<double>>& h_in,
162+
std::vector<ModuleBase::matrix>& h_out,
163+
const int nlocal,
164+
const int nks);
165+
166+
template void DeePKS_domain::collect_h_mat<std::complex<double>, ModuleBase::ComplexMatrix>(
167+
const Parallel_Orbitals& pv,
168+
const std::vector<std::vector<std::complex<double>>>& h_in,
169+
std::vector<ModuleBase::ComplexMatrix>& h_out,
170+
const int nlocal,
171+
const int nks);
172+
92173
#endif

0 commit comments

Comments
 (0)