Skip to content

Commit 831663e

Browse files
lshpku authored and wanghuancoder committed
[PHI] Fix grid sample kernel for big tensor (PaddlePaddle#72628)
1 parent 4e331fd commit 831663e

File tree

1 file changed

+96
-89
lines changed

1 file changed

+96
-89
lines changed

paddle/phi/kernels/gpu/grid_sample_kernel.cu

Lines changed: 96 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -23,132 +23,123 @@
2323

2424
namespace phi {
2525

// Maps a normalized grid coordinate in [-1, 1] to an absolute position along
// an axis of length `size`. With align_corners, -1 and +1 map exactly onto
// positions 0 and size-1; otherwise the half-pixel-edge convention is used.
template <typename T, typename IndexT>
static __forceinline__ __device__ T Unnormalize(T coord,
                                                IndexT size,
                                                bool align_corners) {
  const T shifted = coord + 1.f;  // move [-1, 1] to [0, 2]
  if (align_corners) {
    return (shifted / 2) * (size - 1);
  }
  return (shifted * size - 1) / 2;
}
3333

// Clamps a (possibly fractional) index into the valid range [0, max_value - 1].
template <typename T, typename IndexT>
static __forceinline__ __device__ T ClipIndexes(T in, IndexT max_value) {
  const T lower_bounded = max(in, static_cast<T>(0));
  const T upper = static_cast<T>(max_value - 1);
  return min(upper, lower_bounded);
}
3838

// Reflects an index back into the range [twice_low/2, twice_high/2] by
// mirroring it at the range boundaries (used for reflect padding). The bounds
// are passed doubled so that half-integer boundaries stay exact in IndexT.
template <typename T, typename IndexT>
static __forceinline__ __device__ T ReflectIndexes(T in,
                                                   IndexT twice_low,
                                                   IndexT twice_high) {
  // Degenerate range: every input reflects onto the single point 0.
  if (twice_low == twice_high) {
    return static_cast<T>(0);
  }
  const T low = static_cast<T>(twice_low) / 2;
  const T half_range = static_cast<T>(twice_high - twice_low) / 2;
  const T offset = fabs(in - low);
  const T remainder = fmod(offset, half_range);
  const IndexT fold_count = floor(offset / half_range);
  if (fold_count & 1) {
    // Odd number of folds: the segment is traversed in reverse.
    return half_range - remainder + low;
  }
  // Even number of folds: the segment is traversed forward.
  return remainder + low;
}
5353

// Converts a normalized grid coordinate into a sampling position along an
// axis of length `size`, applying the requested padding mode, then clamps the
// result into a safe integer range via SafeDownGradeToIntRange.
template <typename T, typename IndexT>
static __forceinline__ __device__ T ComputePositions(T coord,
                                                     IndexT size,
                                                     PaddingMode padding_mode,
                                                     bool align_corners) {
  T pos = Unnormalize(coord, size, align_corners);
  if (padding_mode == PaddingMode::reflect) {
    // Reflection bounds are doubled; they differ depending on whether the
    // extreme coordinates sit on pixel centers (align_corners) or edges.
    if (align_corners) {
      pos = ReflectIndexes<T, IndexT>(pos, 0, 2 * (size - 1));
    } else {
      pos = ReflectIndexes<T, IndexT>(pos, -1, 2 * size - 1);
    }
    pos = ClipIndexes(pos, size);
  } else if (padding_mode == PaddingMode::border) {
    pos = ClipIndexes(pos, size);
  }
  return SafeDownGradeToIntRange(pos);
}
6969

70-
template <typename T>
71-
__global__ void GridSampleCudaKernel(const int nthreads,
72-
int n,
73-
int out_c,
74-
int out_h,
75-
int out_w,
76-
int in_h,
77-
int in_w,
78-
const T* input,
79-
const T* grid,
80-
T* output,
70+
template <typename T, typename IndexT>
71+
__global__ void GridSampleCudaKernel(IndexT n,
72+
IndexT out_c,
73+
IndexT out_hw,
74+
IndexT in_h,
75+
IndexT in_w,
76+
const T* __restrict__ input,
77+
const T* __restrict__ grid,
78+
T* __restrict__ output,
8179
const Mode mode,
8280
const PaddingMode padding_mode,
8381
bool align_corners) {
84-
int inp_sN = out_c * in_h * in_w;
85-
86-
int inp_sC = in_h * in_w;
87-
int inp_sH = in_w;
88-
int inp_sW = 1;
89-
int grid_sN = out_h * out_w * 2;
90-
int grid_sH = out_w * 2;
91-
int grid_sW = 2;
92-
int grid_sCoor = 1;
93-
int out_sN = out_c * out_h * out_w;
94-
int out_sC = out_h * out_w;
95-
int out_sH = out_w;
96-
int out_sW = 1;
97-
CUDA_KERNEL_LOOP(index, nthreads) {
98-
const int w = index % out_w;
99-
const int h = (index / out_w) % out_h;
100-
const int n = index / (out_h * out_w);
101-
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
82+
IndexT nthreads = n * out_hw;
83+
IndexT inp_sN = out_c * (in_h * in_w);
84+
IndexT inp_sC = in_h * in_w;
85+
IndexT inp_sH = in_w;
86+
IndexT inp_sW = 1;
87+
IndexT grid_sNHW = 2;
88+
IndexT grid_sCoor = 1;
89+
IndexT out_sN = out_c * out_hw;
90+
IndexT out_sC = out_hw;
91+
IndexT out_sHW = 1;
92+
CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) {
93+
const IndexT hw = index % out_hw;
94+
const IndexT n = index / out_hw;
95+
const IndexT grid_offset = index * grid_sNHW;
10296

10397
T ix = grid[grid_offset];
10498
T iy = grid[grid_offset + grid_sCoor];
10599

106100
ix = ComputePositions(ix, in_w, padding_mode, align_corners);
107101
iy = ComputePositions(iy, in_h, padding_mode, align_corners);
108102
if (mode == Mode::bilinear) {
109-
int ix_nw = static_cast<int>(floor(ix));
110-
int iy_nw = static_cast<int>(floor(iy));
111-
int ix_ne = ix_nw + 1;
112-
int iy_ne = iy_nw;
113-
int ix_sw = ix_nw;
114-
int iy_sw = iy_nw + 1;
115-
int ix_se = ix_nw + 1;
116-
int iy_se = iy_nw + 1;
103+
IndexT ix_nw = floor(ix);
104+
IndexT iy_nw = floor(iy);
105+
IndexT ix_ne = ix_nw + 1;
106+
IndexT iy_ne = iy_nw;
107+
IndexT ix_sw = ix_nw;
108+
IndexT iy_sw = iy_nw + 1;
109+
IndexT ix_se = ix_nw + 1;
110+
IndexT iy_se = iy_nw + 1;
117111

118112
T nw = (ix_se - ix) * (iy_se - iy);
119113
T ne = (ix - ix_sw) * (iy_sw - iy);
120114
T sw = (ix_ne - ix) * (iy - iy_ne);
121115
T se = (ix - ix_nw) * (iy - iy_nw);
122116

123-
auto inp_offset_NC = n * inp_sN;
117+
IndexT inp_offset_NC = n * inp_sN;
118+
T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW);
124119

125-
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
126-
for (int c = 0; c < out_c;
120+
for (IndexT c = 0; c < out_c;
127121
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
128-
*out_ptr_NCHW = static_cast<T>(0);
122+
T value{0};
129123
if (InBounds(iy_nw, ix_nw, in_h, in_w)) {
130-
*out_ptr_NCHW +=
131-
input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
124+
value += input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
132125
}
133126
if (InBounds(iy_ne, ix_ne, in_h, in_w)) {
134-
*out_ptr_NCHW +=
135-
input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
127+
value += input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
136128
}
137129
if (InBounds(iy_sw, ix_sw, in_h, in_w)) {
138-
*out_ptr_NCHW +=
139-
input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
130+
value += input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
140131
}
141132
if (InBounds(iy_se, ix_se, in_h, in_w)) {
142-
*out_ptr_NCHW +=
143-
input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
133+
value += input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
144134
}
135+
*out_ptr_NCHW = value;
145136
}
146137
} else if (mode == Mode::nearest) {
147-
int ix_nearest = static_cast<int>(std::nearbyint(ix));
148-
int iy_nearest = static_cast<int>(std::nearbyint(iy));
149-
auto inp_offset_NC = n * inp_sN;
150-
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
151-
for (int c = 0; c < out_c;
138+
IndexT ix_nearest = std::nearbyint(ix);
139+
IndexT iy_nearest = std::nearbyint(iy);
140+
IndexT inp_offset_NC = n * inp_sN;
141+
T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW);
142+
for (IndexT c = 0; c < out_c;
152143
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
153144
if (InBounds(iy_nearest, ix_nearest, in_h, in_w)) {
154145
*out_ptr_NCHW =
@@ -349,38 +340,54 @@ void GridSampleKernel(const Context& dev_ctx,
349340
}
350341

351342
if (x.dims().size() == 4) {
352-
const int n = grid.dims()[0];
353-
const int out_h = grid.dims()[1];
354-
const int out_w = grid.dims()[2];
355-
const int c = x.dims()[1];
356-
const int in_h = x.dims()[2];
357-
const int in_w = x.dims()[3];
343+
const int64_t n = grid.dims()[0];
344+
const int64_t out_h = grid.dims()[1];
345+
const int64_t out_w = grid.dims()[2];
346+
const int64_t c = x.dims()[1];
347+
const int64_t in_h = x.dims()[2];
348+
const int64_t in_w = x.dims()[3];
358349
VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h
359350
<< "; out_w: " << out_w;
360351

361352
auto* output_data = dev_ctx.template Alloc<T>(out);
362353
VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1] << "; "
363354
<< out->dims()[2] << "; " << out->dims()[3];
364355

365-
int count = static_cast<int>(n * out_h * out_w);
356+
int64_t count = n * out_h * out_w;
366357
auto cu_stream = dev_ctx.stream();
367358
backends::gpu::GpuLaunchConfig config =
368359
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
369-
GridSampleCudaKernel<T>
370-
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
371-
count,
372-
n,
373-
c,
374-
out_h,
375-
out_w,
376-
in_h,
377-
in_w,
378-
x.data<T>(),
379-
grid.data<T>(),
380-
output_data,
381-
enum_mode,
382-
enum_padding_mode,
383-
align_corners);
360+
if (x.numel() <= std::numeric_limits<int>::max() &&
361+
grid.numel() <= std::numeric_limits<int>::max() &&
362+
out->numel() <= std::numeric_limits<int>::max()) {
363+
GridSampleCudaKernel<T, int>
364+
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
365+
n,
366+
c,
367+
out_h * out_w,
368+
in_h,
369+
in_w,
370+
x.data<T>(),
371+
grid.data<T>(),
372+
output_data,
373+
enum_mode,
374+
enum_padding_mode,
375+
align_corners);
376+
} else {
377+
GridSampleCudaKernel<T, int64_t>
378+
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
379+
n,
380+
c,
381+
out_h * out_w,
382+
in_h,
383+
in_w,
384+
x.data<T>(),
385+
grid.data<T>(),
386+
output_data,
387+
enum_mode,
388+
enum_padding_mode,
389+
align_corners);
390+
}
384391
} else {
385392
const int n = grid.dims()[0];
386393
const int out_d = grid.dims()[1];

0 commit comments

Comments
 (0)