Skip to content

Commit 142adc2

Browse files
authored
refactor: remove template on relax_driver (#4172)
1 parent c6f8609 commit 142adc2

File tree

7 files changed

+51
-32
lines changed

7 files changed

+51
-32
lines changed

source/driver_run.cpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,8 @@ void Driver::driver_run(void)
5757
}
5858
else //! scf; cell relaxation; nscf; etc
5959
{
60-
// mixed-precision should not be like this, mohan 2024-05-12,
61-
// DEVICE should not depend on psi
62-
if (GlobalV::precision_flag == "single")
63-
{
64-
Relax_Driver<float, base_device::DEVICE_CPU> rl_driver;
65-
rl_driver.relax_driver(p_esolver);
66-
}
67-
else
68-
{
69-
Relax_Driver<double, base_device::DEVICE_CPU> rl_driver;
70-
rl_driver.relax_driver(p_esolver);
71-
}
60+
Relax_Driver rl_driver;
61+
rl_driver.relax_driver(p_esolver);
7262
}
7363
// "others" in ESolver should be here.
7464

source/module_esolver/esolver.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,19 @@ class ESolver
4747

4848
// temporarily
4949
// get iterstep used in current scf
50-
virtual int getniter()
50+
virtual int get_niter()
51+
{
52+
return 0;
53+
}
54+
55+
// get maxniter used in current scf
56+
virtual int get_maxniter()
57+
{
58+
return 0;
59+
}
60+
61+
// get conv_elec used in current scf
62+
virtual bool get_conv_elec()
5163
{
5264
return 0;
5365
}

source/module_esolver/esolver_ks.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -716,14 +716,33 @@ void ESolver_KS<T, Device>::write_head(std::ofstream& ofs_running, const int ist
716716
//! mohan add 2024-05-12
717717
//------------------------------------------------------------------------------
718718
template<typename T, typename Device>
719-
int ESolver_KS<T, Device>::getniter()
719+
int ESolver_KS<T, Device>::get_niter()
720720
{
721721
return this->niter;
722722
}
723723

724+
//------------------------------------------------------------------------------
725+
//! the 11th function of ESolver_KS: get_maxniter
726+
//! tqzhao add 2024-05-15
727+
//------------------------------------------------------------------------------
728+
template<typename T, typename Device>
729+
int ESolver_KS<T, Device>::get_maxniter()
730+
{
731+
return this->maxniter;
732+
}
733+
734+
//------------------------------------------------------------------------------
735+
//! the 12th function of ESolver_KS: get_conv_elec
736+
//! tqzhao add 2024-05-15
737+
//------------------------------------------------------------------------------
738+
template<typename T, typename Device>
739+
bool ESolver_KS<T, Device>::get_conv_elec()
740+
{
741+
return this->conv_elec;
742+
}
724743

725744
//------------------------------------------------------------------------------
726-
//! the 11th function of ESolver_KS: create_Output_Rho
745+
//! the 13th function of ESolver_KS: create_Output_Rho
727746
//! mohan add 2024-05-12
728747
//------------------------------------------------------------------------------
729748
template<typename T, typename Device>
@@ -765,7 +784,7 @@ ModuleIO::Output_Rho ESolver_KS<T, Device>::create_Output_Rho(
765784

766785

767786
//------------------------------------------------------------------------------
768-
//! the 12th function of ESolver_KS: create_Output_Kin
787+
//! the 14th function of ESolver_KS: create_Output_Kin
769788
//! mohan add 2024-05-12
770789
//------------------------------------------------------------------------------
771790
template<typename T, typename Device>
@@ -789,7 +808,7 @@ ModuleIO::Output_Rho ESolver_KS<T, Device>::create_Output_Kin(int is, int iter,
789808

790809

791810
//------------------------------------------------------------------------------
792-
//! the 13th function of ESolver_KS: create_Output_Potential
811+
//! the 15th function of ESolver_KS: create_Output_Potential
793812
//! mohan add 2024-05-12
794813
//------------------------------------------------------------------------------
795814
template<typename T, typename Device>
@@ -814,7 +833,7 @@ ModuleIO::Output_Potential ESolver_KS<T, Device>::create_Output_Potential(int it
814833

815834

816835
//------------------------------------------------------------------------------
817-
//! the 14th-18th functions of ESolver_KS
836+
//! the 16th-20th functions of ESolver_KS
818837
//! mohan add 2024-05-12
819838
//------------------------------------------------------------------------------
820839
//! This is for mixed-precision pw/LCAO basis sets.

source/module_esolver/esolver_ks.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,13 @@ class ESolver_KS : public ESolver_FP
5656
virtual void hamilt2estates(const double ethr){};
5757

5858
// get current step of Ionic simulation
59-
virtual int getniter() override;
59+
virtual int get_niter() override;
60+
61+
// get maxniter used in current scf
62+
virtual int get_maxniter() override;
63+
64+
// get conv_elec used in current scf
65+
virtual bool get_conv_elec() override;
6066

6167
protected:
6268
//! Something to do before SCF iterations.

source/module_esolver/esolver_of.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class ESolver_OF : public ESolver_FP
3434

3535
virtual void cal_stress(ModuleBase::matrix& stress) override;
3636

37-
virtual int getniter() override
37+
virtual int get_niter() override
3838
{
3939
return this->iter_;
4040
}

source/module_relax/relax_driver.cpp

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88
#include "module_io/json_output/output_info.h"
99

1010

11-
12-
template<typename FPTYPE, typename Device>
13-
void Relax_Driver<FPTYPE, Device>::relax_driver(ModuleESolver::ESolver *p_esolver)
11+
void Relax_Driver::relax_driver(ModuleESolver::ESolver *p_esolver)
1412
{
1513
ModuleBase::TITLE("Ions", "opt_ions");
1614
ModuleBase::timer::tick("Ions", "opt_ions");
@@ -115,12 +113,10 @@ void Relax_Driver<FPTYPE, Device>::relax_driver(ModuleESolver::ESolver *p_esolve
115113
GlobalC::ucell.print_cell_cif("STRU_NOW.cif");
116114
}
117115

118-
ModuleESolver::ESolver_KS<FPTYPE, Device>* p_esolver_ks
119-
= dynamic_cast<ModuleESolver::ESolver_KS<FPTYPE, Device>*>(p_esolver);
120-
if (p_esolver_ks
116+
if (p_esolver
121117
&& stop
122-
&& p_esolver_ks->maxniter == p_esolver_ks->niter
123-
&& !(p_esolver_ks->conv_elec))
118+
&& p_esolver->get_maxniter() == p_esolver->get_niter()
119+
&& !(p_esolver->get_conv_elec()))
124120
{
125121
std::cout << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
126122
std::cout << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
@@ -165,6 +161,3 @@ void Relax_Driver<FPTYPE, Device>::relax_driver(ModuleESolver::ESolver *p_esolve
165161
ModuleBase::timer::tick("Ions", "opt_ions");
166162
return;
167163
}
168-
169-
template class Relax_Driver<float, base_device::DEVICE_CPU>;
170-
template class Relax_Driver<double, base_device::DEVICE_CPU>;

source/module_relax/relax_driver.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include "relax_new/relax.h"
77
#include "relax_old/relax_old.h"
88

9-
template <typename FPTYPE, typename Device = base_device::DEVICE_CPU>
109
class Relax_Driver
1110
{
1211

0 commit comments

Comments
 (0)