Skip to content

Commit ef53ec1

Browse files
QianruipkuWHUweiqingzhoumohanchen
authored
Fix: bug of out_band when kpar > 1 (#3963)
* Fix: bug of out_band when kpar > 1 * fix bug in UT of Gatherkvec * add ndim variables in parallel_kpoints * initialize variables in parallel_kpoints --------- Co-authored-by: wqzhou <[email protected]> Co-authored-by: Mohan Chen <[email protected]>
1 parent 426fc92 commit ef53ec1

File tree

9 files changed

+198
-82
lines changed

9 files changed

+198
-82
lines changed

source/module_cell/parallel_kpoints.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ void Parallel_Kpoints::kinfo(int &nkstot)
2525
this->get_nks_pool(nkstot);
2626
this->get_startk_pool(nkstot);
2727
this->get_whichpool(nkstot);
28+
this->kpar = GlobalV::KPAR;
29+
this->nkstot_np = nkstot;
30+
this->nks_np = this->nks_pool[GlobalV::MY_POOL];
31+
#else
32+
this->kpar = 1;
33+
this->nkstot_np = nkstot;
34+
this->nks_np = nkstot;
2835
#endif
2936
return;
3037
}
@@ -91,6 +98,19 @@ void Parallel_Kpoints::get_startk_pool(const int &nkstot)
9198
}
9299
return;
93100
}
101+
102+
void Parallel_Kpoints::gatherkvec(const std::vector<ModuleBase::Vector3<double>>& vec_local,
103+
std::vector<ModuleBase::Vector3<double>>& vec_global) const
104+
{
105+
vec_global.resize(this->nkstot_np, ModuleBase::Vector3<double>(0.0, 0.0, 0.0));
106+
for (int i = 0; i < this->nks_np; ++i)
107+
{
108+
vec_global[i + startk_pool[GlobalV::MY_POOL]] = vec_local[i] / double(GlobalV::NPROC_IN_POOL);
109+
}
110+
111+
MPI_Allreduce(MPI_IN_PLACE, &vec_global[0], 3 * this->nkstot_np, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
112+
return;
113+
}
94114
#endif
95115

96116

source/module_cell/parallel_kpoints.h

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,61 @@
11
#ifndef PARALLEL_KPOINTS_H
22
#define PARALLEL_KPOINTS_H
33

4+
#include "module_base/complexarray.h"
45
#include "module_base/global_function.h"
56
#include "module_base/global_variable.h"
6-
#include "module_base/complexarray.h"
77
#include "module_base/realarray.h"
8+
#include "module_base/vector3.h"
89

910
class Parallel_Kpoints
1011
{
11-
public:
12-
13-
Parallel_Kpoints();
14-
~Parallel_Kpoints();
15-
16-
void kinfo(int &nkstot);
17-
18-
// collect value from each pool to wk.
19-
void pool_collection(double &value, const double *wk, const int &ik);
20-
21-
// collect value from each pool to overlap.
22-
void pool_collection(double *valuea, double *valueb, const ModuleBase::realArray &a, const ModuleBase::realArray &b, const int &ik);
23-
void pool_collection(std::complex<double> *value, const ModuleBase::ComplexArray &w, const int &ik);
12+
public:
13+
Parallel_Kpoints();
14+
~Parallel_Kpoints();
15+
16+
void kinfo(int& nkstot);
17+
18+
// collect value from each pool to wk.
19+
void pool_collection(double& value, const double* wk, const int& ik);
20+
21+
// collect value from each pool to overlap.
22+
void pool_collection(double* valuea,
23+
double* valueb,
24+
const ModuleBase::realArray& a,
25+
const ModuleBase::realArray& b,
26+
const int& ik);
27+
void pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik);
28+
#ifdef __MPI
29+
/**
30+
* @brief gather kpoints from all processors
31+
*
32+
* @param vec_local kpoint vector in local processor
33+
* @param vec_global kpoint vector in all processors
34+
*/
35+
void gatherkvec(const std::vector<ModuleBase::Vector3<double>>& vec_local,
36+
std::vector<ModuleBase::Vector3<double>>& vec_global) const;
37+
#endif
2438

25-
// information about pool
26-
int *nproc_pool;
27-
int *startpro_pool;
39+
// information about pool, dim: GlobalV::KPAR
40+
int* nproc_pool = nullptr;
41+
int* startpro_pool = nullptr;
2842

29-
// inforamation about kpoints //qianrui add comment
30-
int* nks_pool; //number of k-points in each pool
31-
int* startk_pool; //the first k-point in each pool
32-
int* whichpool; //whichpool[k] : the pool which k belongs to
43+
// inforamation about kpoints, dim: GlobalV::KPAR
44+
int* nks_pool = nullptr; // number of k-points in each pool
45+
int* startk_pool = nullptr; // the first k-point in each pool
46+
int kpar = 0; // number of pools
3347

34-
private:
48+
// information about which pool each k-point belongs to,
49+
int* whichpool = nullptr; // whichpool[k] : the pool which k belongs to, dim: nkstot_np
50+
int nkstot_np = 0; // number of k-points without spin, kv.nkstot = nkstot_np * nspin(1 or 2)
51+
int nks_np = 0; // number of k-points without spin in the present pool
3552

53+
private:
3654
#ifdef __MPI
37-
void get_nks_pool(const int &nkstot);
38-
void get_startk_pool(const int &nkstot);
39-
void get_whichpool(const int &nkstot);
55+
void get_nks_pool(const int& nkstot);
56+
void get_startk_pool(const int& nkstot);
57+
void get_whichpool(const int& nkstot);
4058
#endif
41-
42-
4359
};
4460

4561
#endif

source/module_cell/test/parallel_kpoints_test.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
* to call another three functions: get_nks_pool(),
2424
* get_startk_pool(), get_whichpool(), which divide all kpoints
2525
* into KPAR groups.
26+
* iii.Parallel_Kpoints::gatherkvec() is an interface to gather kpoints
27+
* vectors from all processors.
2628
* The default number of processes is set to 4 in parallel_kpoints_test.sh.
2729
* One may modify it to do more tests, or adapt this unittest to local
2830
* environment.
@@ -120,6 +122,88 @@ class ParaKpoints : public ::testing::TestWithParam<ParaPrepare>
120122
{
121123
};
122124

125+
TEST(Parallel_KpointsTest, GatherkvecTest) {
126+
// Initialize Parallel_Kpoints object
127+
Parallel_Kpoints parallel_kpoints;
128+
129+
// Initialize local and global vectors
130+
std::vector<ModuleBase::Vector3<double>> vec_local;
131+
std::vector<ModuleBase::Vector3<double>> vec_global;
132+
133+
// Populate vec_local with some data
134+
int npool = 1;
135+
if(GlobalV::NPROC > 2)
136+
{
137+
npool = 3;
138+
}
139+
else if(GlobalV::NPROC == 2)
140+
{
141+
npool = 2;
142+
}
143+
GlobalV::KPAR = npool;
144+
145+
if(GlobalV::MY_RANK == 0)
146+
{
147+
vec_local.push_back(ModuleBase::Vector3<double>(1.0, 1.0, 1.0));
148+
GlobalV::NPROC_IN_POOL = 1;
149+
GlobalV::MY_POOL = 0;
150+
parallel_kpoints.nks_np = 1;
151+
}
152+
else if(GlobalV::MY_RANK == 1)
153+
{
154+
vec_local.push_back(ModuleBase::Vector3<double>(2.0, 2.0, 2.0));
155+
vec_local.push_back(ModuleBase::Vector3<double>(3.0, 4.0, 5.0));
156+
GlobalV::NPROC_IN_POOL = 1;
157+
GlobalV::MY_POOL = 1;
158+
parallel_kpoints.nks_np = 2;
159+
}
160+
else
161+
{
162+
vec_local.push_back(ModuleBase::Vector3<double>(3.0, 3.0, 3.0));
163+
GlobalV::NPROC_IN_POOL = GlobalV::NPROC - 2;
164+
GlobalV::MY_POOL = 2;
165+
parallel_kpoints.nks_np = 1;
166+
}
167+
168+
parallel_kpoints.startk_pool = new int[npool];
169+
parallel_kpoints.nkstot_np = 1;
170+
parallel_kpoints.startk_pool[0] = 0;
171+
if (npool >= 2)
172+
{
173+
parallel_kpoints.nkstot_np += 2;
174+
parallel_kpoints.startk_pool[1] = 1;
175+
}
176+
if(npool >= 3)
177+
{
178+
parallel_kpoints.nkstot_np += 1;
179+
parallel_kpoints.startk_pool[2] = 3;
180+
}
181+
182+
// Call gatherkvec method
183+
parallel_kpoints.gatherkvec(vec_local, vec_global);
184+
185+
// Check the values of vec_global
186+
EXPECT_EQ(vec_global[0].x, 1.0);
187+
EXPECT_EQ(vec_global[0].y, 1.0);
188+
EXPECT_EQ(vec_global[0].z, 1.0);
189+
190+
if(npool >= 2)
191+
{
192+
EXPECT_EQ(vec_global[1].x, 2.0);
193+
EXPECT_EQ(vec_global[1].y, 2.0);
194+
EXPECT_EQ(vec_global[1].z, 2.0);
195+
EXPECT_EQ(vec_global[2].x, 3.0);
196+
EXPECT_EQ(vec_global[2].y, 4.0);
197+
EXPECT_EQ(vec_global[2].z, 5.0);
198+
}
199+
if(npool >= 3)
200+
{
201+
EXPECT_EQ(vec_global[3].x, 3.0);
202+
EXPECT_EQ(vec_global[3].y, 3.0);
203+
EXPECT_EQ(vec_global[3].z, 3.0);
204+
}
205+
}
206+
123207
TEST_P(ParaKpoints,DividePools)
124208
{
125209
ParaPrepare pp = GetParam();

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -393,24 +393,13 @@ void ESolver_KS_LCAO<TK, TR>::post_process(void)
393393

394394
if (INPUT.out_band[0]) // pengfei 2014-10-13
395395
{
396-
int nks = 0;
397-
if (nspin0 == 1)
398-
{
399-
nks = this->kv.nkstot;
400-
}
401-
else if (nspin0 == 2)
402-
{
403-
nks = this->kv.nkstot / 2;
404-
}
405-
406396
for (int is = 0; is < nspin0; is++)
407397
{
408398
std::stringstream ss2;
409399
ss2 << GlobalV::global_out_dir << "BANDS_" << is + 1 << ".dat";
410400
GlobalV::ofs_running << "\n Output bands in file: " << ss2.str() << std::endl;
411401
ModuleIO::nscf_band(is,
412402
ss2.str(),
413-
nks,
414403
GlobalV::NBANDS,
415404
0.0,
416405
INPUT.out_band[1],

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,23 +1257,13 @@ void ESolver_KS_PW<T, Device>::post_process(void)
12571257

12581258
if (INPUT.out_band[0]) // pengfei 2014-10-13
12591259
{
1260-
int nks = 0;
1261-
if (nspin0 == 1)
1262-
{
1263-
nks = this->kv.nkstot;
1264-
}
1265-
else if (nspin0 == 2)
1266-
{
1267-
nks = this->kv.nkstot / 2;
1268-
}
12691260
for (int is = 0; is < nspin0; is++)
12701261
{
12711262
std::stringstream ss2;
12721263
ss2 << GlobalV::global_out_dir << "BANDS_" << is + 1 << ".dat";
12731264
GlobalV::ofs_running << "\n Output bands in file: " << ss2.str() << std::endl;
12741265
ModuleIO::nscf_band(is,
12751266
ss2.str(),
1276-
nks,
12771267
GlobalV::NBANDS,
12781268
0.0,
12791269
INPUT.out_band[1],

source/module_io/nscf_band.cpp

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
void ModuleIO::nscf_band(
99
const int &is,
1010
const std::string &out_band_dir,
11-
const int &nks,
1211
const int &nband,
1312
const double &fermie,
1413
const int &precision,
@@ -18,6 +17,10 @@ void ModuleIO::nscf_band(
1817
{
1918
ModuleBase::TITLE("ModuleIO","nscf_band");
2019
ModuleBase::timer::tick("ModuleIO", "nscf_band");
20+
// number of k points without spin; nspin = 1,2, nkstot = nkstot_np * nspin;
21+
// nspin = 4, nkstot = nkstot_np
22+
const int nkstot_np = Pkpoints->nkstot_np;
23+
const int nks_np = Pkpoints->nks_np;
2124

2225
#ifdef __MPI
2326
if(GlobalV::MY_RANK==0)
@@ -26,15 +29,16 @@ void ModuleIO::nscf_band(
2629
ofs.close();
2730
}
2831
MPI_Barrier(MPI_COMM_WORLD);
29-
3032
std::vector<double> klength;
31-
klength.resize(nks);
33+
klength.resize(nkstot_np);
3234
klength[0] = 0.0;
33-
for(int ik=0; ik<nks; ik++)
35+
std::vector<ModuleBase::Vector3<double>> kvec_c_global;
36+
Pkpoints->gatherkvec(kv.kvec_c, kvec_c_global);
37+
for(int ik=0; ik<nkstot_np; ik++)
3438
{
3539
if (ik>0)
3640
{
37-
auto delta=kv.kvec_c[ik]-kv.kvec_c[ik-1];
41+
auto delta=kvec_c_global[ik]-kvec_c_global[ik-1];
3842
klength[ik] = klength[ik-1];
3943
klength[ik] += (kv.kl_segids[ik] == kv.kl_segids[ik-1]) ? delta.norm() : 0.0;
4044
}
@@ -44,23 +48,21 @@ void ModuleIO::nscf_band(
4448
/* then get the local kpoint index, which starts definitly from 0 */
4549
const int ik_now = ik - Pkpoints->startk_pool[GlobalV::MY_POOL];
4650
/* if present kpoint corresponds the spin of the present one */
47-
if( kv.isk[ik_now+is*nks] == is )
48-
{
49-
if ( GlobalV::RANK_IN_POOL == 0)
51+
assert( kv.isk[ik_now+is*nks_np] == is );
52+
if ( GlobalV::RANK_IN_POOL == 0)
53+
{
54+
formatter::PhysicalFmt physfmt; // create a physical formatter temporarily
55+
std::ofstream ofs(out_band_dir.c_str(), std::ios::app);
56+
physfmt.adjust_formatter_flexible(4, 0, false); // for integer
57+
ofs << physfmt.get_p_formatter()->format(ik+1);
58+
physfmt.adjust_formatter_flexible(precision, 4.0/double(precision), false); // for decimal
59+
ofs << physfmt.get_p_formatter()->format(klength[ik]);
60+
for(int ib = 0; ib < nband; ib++)
5061
{
51-
formatter::PhysicalFmt physfmt; // create a physical formatter temporarily
52-
std::ofstream ofs(out_band_dir.c_str(), std::ios::app);
53-
physfmt.adjust_formatter_flexible(4, 0, false); // for integer
54-
ofs << physfmt.get_p_formatter()->format(ik+1);
55-
physfmt.adjust_formatter_flexible(precision, 4.0/double(precision), false); // for decimal
56-
ofs << physfmt.get_p_formatter()->format(klength[ik]);
57-
for(int ib = 0; ib < nband; ib++)
58-
{
59-
ofs << physfmt.get_p_formatter()->format((ekb(ik_now+is*nks, ib)-fermie) * ModuleBase::Ry_to_eV);
60-
}
61-
ofs << std::endl;
62-
ofs.close();
62+
ofs << physfmt.get_p_formatter()->format((ekb(ik_now+is*nks_np, ib)-fermie) * ModuleBase::Ry_to_eV);
6363
}
64+
ofs << std::endl;
65+
ofs.close();
6466
}
6567
}
6668
MPI_Barrier(MPI_COMM_WORLD);
@@ -73,7 +75,7 @@ void ModuleIO::nscf_band(
7375
if(GlobalV::MY_POOL == ip && GlobalV::RANK_IN_POOL == 0)
7476
{
7577
std::ofstream ofs(out_band_dir.c_str(),ios::app);
76-
for(int ik=0;ik<nks;ik++)
78+
for(int ik=0;ik<nkstot_np;ik++)
7779
{
7880
ofs<<std::setw(12)<<ik;
7981
for(int ib = 0; ib < nband; ib++)
@@ -92,10 +94,10 @@ void ModuleIO::nscf_band(
9294
// std::cout<<out_band_dir<<std::endl;
9395
formatter::PhysicalFmt physfmt; // create a physical formatter temporarily
9496
std::vector<double> klength;
95-
klength.resize(nks);
97+
klength.resize(nkstot_np);
9698
klength[0] = 0.0;
9799
std::ofstream ofs(out_band_dir.c_str());
98-
for(int ik=0;ik<nks;ik++)
100+
for(int ik=0;ik<nkstot_np;ik++)
99101
{
100102
if (ik>0)
101103
{

0 commit comments

Comments
 (0)