Skip to content

Commit 824f140

Browse files
authored
[PHI] Fix grid sample 3d kernel for big tensor (#73253)
1 parent 87465ef commit 824f140

File tree

1 file changed

+124
-122
lines changed

1 file changed

+124
-122
lines changed

paddle/phi/kernels/gpu/grid_sample_kernel.cu

Lines changed: 124 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -152,45 +152,45 @@ __global__ void GridSampleCudaKernel(IndexT n,
152152
}
153153
}
154154

155-
template <typename T>
156-
__global__ void GridSample3DCudaKernel(const int nthreads,
157-
int out_c,
158-
int out_d,
159-
int out_h,
160-
int out_w,
161-
int in_d,
162-
int in_h,
163-
int in_w,
155+
template <typename T, typename IndexT>
156+
__global__ void GridSample3DCudaKernel(const IndexT nthreads,
157+
IndexT out_c,
158+
IndexT out_d,
159+
IndexT out_h,
160+
IndexT out_w,
161+
IndexT in_d,
162+
IndexT in_h,
163+
IndexT in_w,
164164
const T* input,
165165
const T* grid,
166166
T* output,
167167
const Mode interpolation_mode,
168168
const PaddingMode padding_mode,
169169
bool align_corners) {
170-
int inp_sW = 1;
171-
int inp_sH = in_w;
172-
int inp_sD = in_h * in_w;
173-
int inp_sC = in_d * inp_sD;
174-
int inp_sN = out_c * inp_sC;
175-
176-
int grid_sCoor = 1;
177-
int grid_sW = 3;
178-
int grid_sH = out_w * grid_sW;
179-
int grid_sD = out_h * grid_sH;
180-
int grid_sN = out_d * grid_sD;
181-
182-
int out_sW = 1;
183-
int out_sH = out_w;
184-
int out_sD = out_h * out_w;
185-
int out_sC = out_d * out_sD;
186-
int out_sN = out_c * out_sC;
187-
188-
CUDA_KERNEL_LOOP_TYPE(index, nthreads, int) {
189-
const int w = index % out_w;
190-
const int h = (index / out_w) % out_h;
191-
const int d = (index / (out_h * out_w)) % out_d;
192-
const int n = index / (out_d * out_h * out_w);
193-
const int grid_offset =
170+
IndexT inp_sW = 1;
171+
IndexT inp_sH = in_w;
172+
IndexT inp_sD = in_h * in_w;
173+
IndexT inp_sC = in_d * inp_sD;
174+
IndexT inp_sN = out_c * inp_sC;
175+
176+
IndexT grid_sCoor = 1;
177+
IndexT grid_sW = 3;
178+
IndexT grid_sH = out_w * grid_sW;
179+
IndexT grid_sD = out_h * grid_sH;
180+
IndexT grid_sN = out_d * grid_sD;
181+
182+
IndexT out_sW = 1;
183+
IndexT out_sH = out_w;
184+
IndexT out_sD = out_h * out_w;
185+
IndexT out_sC = out_d * out_sD;
186+
IndexT out_sN = out_c * out_sC;
187+
188+
CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) {
189+
const IndexT w = index % out_w;
190+
const IndexT h = (index / out_w) % out_h;
191+
const IndexT d = (index / (out_h * out_w)) % out_d;
192+
const IndexT n = index / (out_d * out_h * out_w);
193+
const IndexT grid_offset =
194194
n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
195195
// get the corresponding input x, y, z coordinates from grid
196196
T ix = grid[grid_offset];
@@ -203,37 +203,37 @@ __global__ void GridSample3DCudaKernel(const int nthreads,
203203
// get corner pixel values from (x, y, z)
204204
// for 4d, we used north-east-south-west
205205
// for 5d, we add top-bottom
206-
int ix_tnw = static_cast<int>(std::floor(ix));
207-
int iy_tnw = static_cast<int>(std::floor(iy));
208-
int iz_tnw = static_cast<int>(std::floor(iz));
206+
IndexT ix_tnw = static_cast<IndexT>(std::floor(ix));
207+
IndexT iy_tnw = static_cast<IndexT>(std::floor(iy));
208+
IndexT iz_tnw = static_cast<IndexT>(std::floor(iz));
209209

210-
int ix_tne = ix_tnw + 1;
211-
int iy_tne = iy_tnw;
212-
int iz_tne = iz_tnw;
210+
IndexT ix_tne = ix_tnw + 1;
211+
IndexT iy_tne = iy_tnw;
212+
IndexT iz_tne = iz_tnw;
213213

214-
int ix_tsw = ix_tnw;
215-
int iy_tsw = iy_tnw + 1;
216-
int iz_tsw = iz_tnw;
214+
IndexT ix_tsw = ix_tnw;
215+
IndexT iy_tsw = iy_tnw + 1;
216+
IndexT iz_tsw = iz_tnw;
217217

218-
int ix_tse = ix_tnw + 1;
219-
int iy_tse = iy_tnw + 1;
220-
int iz_tse = iz_tnw;
218+
IndexT ix_tse = ix_tnw + 1;
219+
IndexT iy_tse = iy_tnw + 1;
220+
IndexT iz_tse = iz_tnw;
221221

222-
int ix_bnw = ix_tnw;
223-
int iy_bnw = iy_tnw;
224-
int iz_bnw = iz_tnw + 1;
222+
IndexT ix_bnw = ix_tnw;
223+
IndexT iy_bnw = iy_tnw;
224+
IndexT iz_bnw = iz_tnw + 1;
225225

226-
int ix_bne = ix_tnw + 1;
227-
int iy_bne = iy_tnw;
228-
int iz_bne = iz_tnw + 1;
226+
IndexT ix_bne = ix_tnw + 1;
227+
IndexT iy_bne = iy_tnw;
228+
IndexT iz_bne = iz_tnw + 1;
229229

230-
int ix_bsw = ix_tnw;
231-
int iy_bsw = iy_tnw + 1;
232-
int iz_bsw = iz_tnw + 1;
230+
IndexT ix_bsw = ix_tnw;
231+
IndexT iy_bsw = iy_tnw + 1;
232+
IndexT iz_bsw = iz_tnw + 1;
233233

234-
int ix_bse = ix_tnw + 1;
235-
int iy_bse = iy_tnw + 1;
236-
int iz_bse = iz_tnw + 1;
234+
IndexT ix_bse = ix_tnw + 1;
235+
IndexT iy_bse = iy_tnw + 1;
236+
IndexT iz_bse = iz_tnw + 1;
237237

238238
// get surfaces to each neighbor:
239239
T tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
@@ -245,10 +245,10 @@ __global__ void GridSample3DCudaKernel(const int nthreads,
245245
T bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
246246
T bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
247247

248-
auto inp_ptr_NC = input + n * inp_sN;
249-
auto out_ptr_NCDHW =
250-
output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
251-
for (int c = 0; c < out_c;
248+
const T* inp_ptr_NC = input + n * inp_sN;
249+
T* out_ptr_NCDHW =
250+
output + (n * out_sN + d * out_sD + h * out_sH + w * out_sW);
251+
for (IndexT c = 0; c < out_c;
252252
++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
253253
*out_ptr_NCDHW = static_cast<T>(0);
254254
if (InBounds3D(iz_tnw, iy_tnw, ix_tnw, in_d, in_h, in_w)) {
@@ -293,15 +293,15 @@ __global__ void GridSample3DCudaKernel(const int nthreads,
293293
}
294294
}
295295
} else if (interpolation_mode == Mode::nearest) {
296-
int ix_nearest = static_cast<int>(std::round(ix));
297-
int iy_nearest = static_cast<int>(std::round(iy));
298-
int iz_nearest = static_cast<int>(std::round(iz));
296+
IndexT ix_nearest = static_cast<IndexT>(std::round(ix));
297+
IndexT iy_nearest = static_cast<IndexT>(std::round(iy));
298+
IndexT iz_nearest = static_cast<IndexT>(std::round(iz));
299299

300300
// assign nearest neighbor pixel value to output pixel
301-
auto inp_ptr_NC = input + n * inp_sN;
302-
auto out_ptr_NCDHW =
303-
output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
304-
for (int c = 0; c < out_c;
301+
const T* inp_ptr_NC = input + n * inp_sN;
302+
T* out_ptr_NCDHW =
303+
output + (n * out_sN + d * out_sD + h * out_sH + w * out_sW);
304+
for (IndexT c = 0; c < out_c;
305305
++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
306306
if (InBounds3D(iz_nearest, iy_nearest, ix_nearest, in_d, in_h, in_w)) {
307307
*out_ptr_NCDHW =
@@ -343,6 +343,10 @@ void GridSampleKernel(const Context& dev_ctx,
343343
enum_mode = Mode::bilinear;
344344
}
345345

346+
bool use_int32_index = x.numel() <= std::numeric_limits<int>::max() &&
347+
grid.numel() <= std::numeric_limits<int>::max() &&
348+
out->numel() <= std::numeric_limits<int>::max();
349+
346350
if (x.dims().size() == 4) {
347351
const int64_t n = grid.dims()[0];
348352
const int64_t out_h = grid.dims()[1];
@@ -361,46 +365,36 @@ void GridSampleKernel(const Context& dev_ctx,
361365
auto cu_stream = dev_ctx.stream();
362366
backends::gpu::GpuLaunchConfig config =
363367
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
364-
if (x.numel() <= std::numeric_limits<int>::max() &&
365-
grid.numel() <= std::numeric_limits<int>::max() &&
366-
out->numel() <= std::numeric_limits<int>::max()) {
367-
GridSampleCudaKernel<T, int>
368-
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
369-
n,
370-
c,
371-
out_h * out_w,
372-
in_h,
373-
in_w,
374-
x.data<T>(),
375-
grid.data<T>(),
376-
output_data,
377-
enum_mode,
378-
enum_padding_mode,
379-
align_corners);
368+
369+
#define LAUNCH_KERNEL(INDEX_TYPE) \
370+
GridSampleCudaKernel<T, INDEX_TYPE> \
371+
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>( \
372+
n, \
373+
c, \
374+
out_h * out_w, \
375+
in_h, \
376+
in_w, \
377+
x.data<T>(), \
378+
grid.data<T>(), \
379+
output_data, \
380+
enum_mode, \
381+
enum_padding_mode, \
382+
align_corners)
383+
if (use_int32_index) {
384+
LAUNCH_KERNEL(int);
380385
} else {
381-
GridSampleCudaKernel<T, int64_t>
382-
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
383-
n,
384-
c,
385-
out_h * out_w,
386-
in_h,
387-
in_w,
388-
x.data<T>(),
389-
grid.data<T>(),
390-
output_data,
391-
enum_mode,
392-
enum_padding_mode,
393-
align_corners);
386+
LAUNCH_KERNEL(int64_t);
394387
}
388+
#undef LAUNCH_KERNEL
395389
} else {
396-
const int n = grid.dims()[0];
397-
const int out_d = grid.dims()[1];
398-
const int out_h = grid.dims()[2];
399-
const int out_w = grid.dims()[3];
400-
const int c = x.dims()[1];
401-
const int in_d = x.dims()[2];
402-
const int in_h = x.dims()[3];
403-
const int in_w = x.dims()[4];
390+
const int64_t n = grid.dims()[0];
391+
const int64_t out_d = grid.dims()[1];
392+
const int64_t out_h = grid.dims()[2];
393+
const int64_t out_w = grid.dims()[3];
394+
const int64_t c = x.dims()[1];
395+
const int64_t in_d = x.dims()[2];
396+
const int64_t in_h = x.dims()[3];
397+
const int64_t in_w = x.dims()[4];
404398

405399
VLOG(3) << "n: " << n << "; c: " << c << "; out_d: " << out_d
406400
<< "; out_h: " << out_h << "; out_w: " << out_w;
@@ -410,26 +404,34 @@ void GridSampleKernel(const Context& dev_ctx,
410404
<< out->dims()[2] << "; " << out->dims()[3] << "; "
411405
<< out->dims()[4];
412406

413-
int count = static_cast<int>(n * out_d * out_h * out_w);
407+
int64_t count = n * out_d * out_h * out_w;
414408
auto cu_stream = dev_ctx.stream();
415409
backends::gpu::GpuLaunchConfig config =
416410
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
417-
GridSample3DCudaKernel<T>
418-
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
419-
count,
420-
c,
421-
out_d,
422-
out_h,
423-
out_w,
424-
in_d,
425-
in_h,
426-
in_w,
427-
x.data<T>(),
428-
grid.data<T>(),
429-
output_data,
430-
enum_mode,
431-
enum_padding_mode,
432-
align_corners);
411+
412+
#define LAUNCH_KERNEL(INDEX_TYPE) \
413+
GridSample3DCudaKernel<T, INDEX_TYPE> \
414+
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>( \
415+
count, \
416+
c, \
417+
out_d, \
418+
out_h, \
419+
out_w, \
420+
in_d, \
421+
in_h, \
422+
in_w, \
423+
x.data<T>(), \
424+
grid.data<T>(), \
425+
output_data, \
426+
enum_mode, \
427+
enum_padding_mode, \
428+
align_corners)
429+
if (use_int32_index) {
430+
LAUNCH_KERNEL(int);
431+
} else {
432+
LAUNCH_KERNEL(int64_t);
433+
}
434+
#undef LAUNCH_KERNEL
433435
}
434436
}
435437

0 commit comments

Comments
 (0)