Skip to content

Commit 2610f1e

Browse files
authored
Add templates for recip_to_real in the pw_basis (#6023)
* update recip_to_real in the rhpw * add the operator in pw_basis * remove ctx in the file * remove ctx in bundle * update compile bug * add T,device * modify back the func * update func
1 parent d591a60 commit 2610f1e

File tree

16 files changed

+798
-152
lines changed

16 files changed

+798
-152
lines changed

source/module_basis/module_pw/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ list(APPEND objects
3030
pw_distributer.cpp
3131
pw_init.cpp
3232
pw_transform.cpp
33+
pw_transform_gpu.cpp
3334
pw_transform_k.cpp
3435
module_fft/fft_bundle.cpp
3536
module_fft/fft_cpu.cpp

source/module_basis/module_pw/kernels/cuda/pw_op.cu

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@ __global__ void set_recip_to_real_output(
4141
}
4242
}
4343

44+
/// Kernel: copy or accumulate the real part of a complex buffer into a real buffer.
///
/// Each thread handles at most one element, selected by its flat 1D global
/// index; threads whose index is >= nrxx do nothing.
///
/// @param nrxx   - number of elements to process
/// @param add    - true: out[i] += factor * Re(in[i]); false: out[i] = Re(in[i])
/// @param factor - scale applied on the accumulate path only
/// @param in     - complex input, read at indices [0, nrxx)
/// @param out    - real output, written at indices [0, nrxx)
template<class FPTYPE>
__global__ void set_recip_to_real_output(
    const int nrxx,
    const bool add,
    const FPTYPE factor,
    const thrust::complex<FPTYPE>* in,
    FPTYPE* out)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < nrxx)
    {
        const FPTYPE re = in[tid].real();
        if (add)
        {
            out[tid] += factor * re;
        }
        else
        {
            // factor is not applied on the overwrite path
            out[tid] = re;
        }
    }
}
61+
4462
template<class FPTYPE>
4563
__global__ void set_real_to_recip_output(
4664
const int npwk,
@@ -61,9 +79,28 @@ __global__ void set_real_to_recip_output(
6179
}
6280
}
6381

82+
/// Kernel: gather the real part of selected complex entries into a real buffer,
/// dividing by nxyz.
///
/// Each thread handles at most one output element, selected by its flat 1D
/// global index; threads whose index is >= npwk do nothing. Output element i
/// reads the input at position box_index[i].
///
/// @param npwk      - number of output elements
/// @param nxyz      - divisor applied to every gathered value
/// @param add       - true: out[i] += factor / nxyz * Re(in[box_index[i]]);
///                    false: out[i] = Re(in[box_index[i]]) / nxyz
/// @param factor    - scale applied on the accumulate path only
/// @param box_index - for each output element, the index to read from `in`
/// @param in        - complex input
/// @param out       - real output, written at indices [0, npwk)
template<class FPTYPE>
__global__ void set_real_to_recip_output(
    const int npwk,
    const int nxyz,
    const bool add,
    const FPTYPE factor,
    const int* box_index,
    const thrust::complex<FPTYPE>* in,
    FPTYPE* out)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < npwk)
    {
        const FPTYPE re = in[box_index[tid]].real();
        if (add)
        {
            // same evaluation order as before: (factor / nxyz) * re
            out[tid] += factor / nxyz * re;
        }
        else
        {
            out[tid] = re / nxyz;
        }
    }
}
101+
64102
template <typename FPTYPE>
65-
void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
66-
const int npwk,
103+
void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
67104
const int* box_index,
68105
const std::complex<FPTYPE>* in,
69106
std::complex<FPTYPE>* out)
@@ -79,8 +116,7 @@ void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_d
79116
}
80117

81118
template <typename FPTYPE>
82-
void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
83-
const int nrxx,
119+
void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
84120
const bool add,
85121
const FPTYPE factor,
86122
const std::complex<FPTYPE>* in,
@@ -98,8 +134,25 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
98134
}
99135

100136
template <typename FPTYPE>
101-
void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
102-
const int npwk,
137+
void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
                                                                              const bool add,
                                                                              const FPTYPE factor,
                                                                              const std::complex<FPTYPE>* in,
                                                                              FPTYPE* out)
{
    // Real-valued output overload: launches the set_recip_to_real_output kernel
    // with one thread per element, using THREADS_PER_BLOCK threads per block.
    //   nrxx   - number of elements to process
    //   add    - true: accumulate factor * Re(in) into out; false: overwrite out with Re(in)
    //   factor - scale used on the accumulate path
    //   in     - complex input, passed to the kernel as thrust::complex
    //   out    - real output

    // A grid dimension of 0 is an invalid launch configuration, so an empty
    // (or negative) element count must not reach the kernel launch.
    if (nrxx <= 0)
    {
        return;
    }

    // ceil(nrxx / THREADS_PER_BLOCK) blocks so that every element gets a thread
    const int block = (nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    set_recip_to_real_output<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
        nrxx,
        add,
        factor,
        reinterpret_cast<const thrust::complex<FPTYPE>*>(in),
        out); // already FPTYPE*, no cast needed

    cudaCheckOnDebug();
}
153+
154+
template <typename FPTYPE>
155+
void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
103156
const int nxyz,
104157
const bool add,
105158
const FPTYPE factor,
@@ -120,6 +173,28 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
120173
cudaCheckOnDebug();
121174
}
122175

176+
/// GPU functor, real-valued output overload: launches the
/// set_real_to_recip_output kernel with one thread per output element,
/// using THREADS_PER_BLOCK threads per block.
///
/// @param npwk      - number of output elements
/// @param nxyz      - divisor forwarded to the kernel
/// @param add       - true: accumulate into out; false: overwrite out
/// @param factor    - scale used on the accumulate path
/// @param box_index - per-element index into `in`
/// @param in        - complex input, passed to the kernel as thrust::complex
/// @param out       - real output
template <typename FPTYPE>
void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                              const int nxyz,
                                                                              const bool add,
                                                                              const FPTYPE factor,
                                                                              const int* box_index,
                                                                              const std::complex<FPTYPE>* in,
                                                                              FPTYPE* out)
{
    // A grid dimension of 0 is an invalid launch configuration, so an empty
    // (or negative) element count must not reach the kernel launch.
    if (npwk <= 0)
    {
        return;
    }

    // ceil(npwk / THREADS_PER_BLOCK) blocks so that every element gets a thread
    const int block = (npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    set_real_to_recip_output<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
        npwk,
        nxyz,
        add,
        factor,
        box_index,
        reinterpret_cast<const thrust::complex<FPTYPE>*>(in),
        out); // already FPTYPE*, no cast needed

    cudaCheckOnDebug();
}
197+
123198
template struct set_3d_fft_box_op<float, base_device::DEVICE_GPU>;
124199
template struct set_recip_to_real_output_op<float, base_device::DEVICE_GPU>;
125200
template struct set_real_to_recip_output_op<float, base_device::DEVICE_GPU>;

source/module_basis/module_pw/kernels/pw_op.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ namespace ModulePW {
55
template <typename FPTYPE>
66
struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU>
77
{
8-
void operator()(const base_device::DEVICE_CPU* /*dev*/,
9-
const int npwk,
8+
void operator()(const int npwk,
109
const int* box_index,
1110
const std::complex<FPTYPE>* in,
1211
std::complex<FPTYPE>* out)
@@ -21,8 +20,7 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU>
2120
template <typename FPTYPE>
2221
struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
2322
{
24-
void operator()(const base_device::DEVICE_CPU* /*dev*/,
25-
const int nrxx,
23+
void operator()(const int nrxx,
2624
const bool add,
2725
const FPTYPE factor,
2826
const std::complex<FPTYPE>* in,
@@ -39,13 +37,34 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
3937
}
4038
}
4139
}
40+
41+
void operator()(const int nrxx,
42+
const bool add,
43+
const FPTYPE factor,
44+
const std::complex<FPTYPE>* in,
45+
FPTYPE* out)
46+
{
47+
if (add)
48+
{
49+
for (int ir = 0; ir < nrxx; ++ir)
50+
{
51+
out[ir] += factor * in[ir].real();
52+
}
53+
}
54+
else
55+
{
56+
for (int ir = 0; ir < nrxx; ++ir)
57+
{
58+
out[ir] = in[ir].real();
59+
}
60+
}
61+
}
4262
};
4363

4464
template <typename FPTYPE>
4565
struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>
4666
{
47-
void operator()(const base_device::DEVICE_CPU* /*dev*/,
48-
const int npw_k,
67+
void operator()(const int npw_k,
4968
const int nxyz,
5069
const bool add,
5170
const FPTYPE factor,

source/module_basis/module_pw/kernels/pw_op.h

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ struct set_3d_fft_box_op {
1919
/// Output Parameters
2020
/// @param out - output psi within the 3D box(in recip space)
2121
void operator() (
22-
const Device* dev,
2322
const int npwk,
2423
const int* box_index,
2524
const std::complex<FPTYPE>* in,
@@ -39,12 +38,18 @@ struct set_recip_to_real_output_op {
3938
/// Output Parameters
4039
/// @param out - output psi within the 3D box(in real space)
4140
void operator() (
42-
const Device* dev,
4341
const int nrxx,
4442
const bool add,
4543
const FPTYPE factor,
4644
const std::complex<FPTYPE>* in,
4745
std::complex<FPTYPE>* out);
46+
47+
void operator() (
48+
const int nrxx,
49+
const bool add,
50+
const FPTYPE factor,
51+
const std::complex<FPTYPE>* in,
52+
FPTYPE* out);
4853
};
4954

5055
template <typename FPTYPE, typename Device>
@@ -62,23 +67,30 @@ struct set_real_to_recip_output_op {
6267
/// Output Parameters
6368
/// @param out - output psi within the 3D box(in recip space)
6469
void operator() (
65-
const Device* dev,
6670
const int npw_k,
6771
const int nxyz,
6872
const bool add,
6973
const FPTYPE factor,
7074
const int* box_index,
7175
const std::complex<FPTYPE>* in,
7276
std::complex<FPTYPE>* out);
77+
78+
void operator() (
79+
const int npw_k,
80+
const int nxyz,
81+
const bool add,
82+
const FPTYPE factor,
83+
const int* box_index,
84+
const std::complex<FPTYPE>* in,
85+
FPTYPE* out);
7386
};
7487

7588
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
7689
// Partially specialize functor for base_device::GpuDevice.
7790
template <typename FPTYPE>
7891
struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>
7992
{
80-
void operator()(const base_device::DEVICE_GPU* dev,
81-
const int npwk,
93+
void operator()(const int npwk,
8294
const int* box_index,
8395
const std::complex<FPTYPE>* in,
8496
std::complex<FPTYPE>* out);
@@ -87,25 +99,36 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>
8799
template <typename FPTYPE>
88100
struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>
89101
{
90-
void operator()(const base_device::DEVICE_GPU* dev,
91-
const int nrxx,
102+
void operator()(const int nrxx,
92103
const bool add,
93104
const FPTYPE factor,
94105
const std::complex<FPTYPE>* in,
95106
std::complex<FPTYPE>* out);
107+
108+
void operator()(const int nrxx,
109+
const bool add,
110+
const FPTYPE factor,
111+
const std::complex<FPTYPE>* in,
112+
FPTYPE* out);
96113
};
97114

98115
template <typename FPTYPE>
99116
struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>
100117
{
101-
void operator()(const base_device::DEVICE_GPU* dev,
102-
const int npw_k,
118+
void operator()(const int npw_k,
103119
const int nxyz,
104120
const bool add,
105121
const FPTYPE factor,
106122
const int* box_index,
107123
const std::complex<FPTYPE>* in,
108124
std::complex<FPTYPE>* out);
125+
void operator()(const int npw_k,
126+
const int nxyz,
127+
const bool add,
128+
const FPTYPE factor,
129+
const int* box_index,
130+
const std::complex<FPTYPE>* in,
131+
FPTYPE* out);
109132
};
110133

111134
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM

0 commit comments

Comments
 (0)