Skip to content

Commit 041ab9e

Browse files
committed
[PHI] Fix grid sample kernel for big tensors
1 parent 7b9c77d commit 041ab9e

File tree

1 file changed

+96
-86
lines changed

1 file changed

+96
-86
lines changed

paddle/phi/kernels/gpu/grid_sample_kernel.cu

Lines changed: 96 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -23,132 +23,124 @@
2323

2424
namespace phi {
2525

// Map a normalized coordinate in [-1, 1] to an absolute index position in
// [0, size - 1]. With align_corners the extremes -1/+1 land exactly on the
// first/last element; otherwise they land on the outer pixel edges
// (half-pixel convention). IndexT is int or int64_t depending on tensor size.
template <typename T, typename IndexT>
static __forceinline__ __device__ T Unnormalize(T coord,
                                                IndexT size,
                                                bool align_corners) {
  if (align_corners) {
    return ((coord + 1.f) / 2) * (size - 1);
  }
  return ((coord + 1.f) * size - 1) / 2;
}
3333

// Clamp `in` to the valid index interval [0, max_value - 1].
// The min(max(...)) nesting order is kept so that NaN handling of the CUDA
// device min/max builtins matches the previous behavior exactly.
template <typename T, typename IndexT>
static __forceinline__ __device__ T ClipIndexes(T in, IndexT max_value) {
  const T lower = static_cast<T>(0);
  const T upper = static_cast<T>(max_value - 1);
  return min(upper, max(in, lower));
}
3838

// Reflect `in` back into the interval [twice_low / 2, twice_high / 2],
// bouncing off both ends as many times as needed (used by the `reflect`
// padding mode). The interval bounds are passed doubled so that callers can
// express half-integer edges (e.g. -0.5) with integer arguments.
template <typename T, typename IndexT>
static __forceinline__ __device__ T ReflectIndexes(T in,
                                                   IndexT twice_low,
                                                   IndexT twice_high) {
  if (twice_low == twice_high) {
    // Degenerate interval: every coordinate reflects onto a single point.
    return static_cast<T>(0);
  }
  const T low = static_cast<T>(twice_low) / 2;
  const T span = static_cast<T>(twice_high - twice_low) / 2;
  T dist = fabs(in - low);
  T extra = fmod(dist, span);
  IndexT flips = floor(dist / span);
  // An even number of reflections lands `extra` past `low`;
  // an odd number lands mirrored back from the far edge.
  if (flips & 1) {
    return span - extra + low;
  }
  return extra + low;
}
5353

// Convert one normalized grid coordinate into a sampling position inside the
// input extent `size`, applying the requested padding mode:
//   - border:  clamp to [0, size - 1]
//   - reflect: fold back into range, then clamp
//   - zeros:   leave as-is (out-of-bounds samples are skipped by the caller)
// The result is finally squashed by SafeDownGradeToIntRange so later
// float -> integer index conversion cannot overflow.
template <typename T, typename IndexT>
static __forceinline__ __device__ T ComputePositions(T coord,
                                                     IndexT size,
                                                     PaddingMode padding_mode,
                                                     bool align_corners) {
  coord = Unnormalize(coord, size, align_corners);
  switch (padding_mode) {
    case PaddingMode::border:
      coord = ClipIndexes(coord, size);
      break;
    case PaddingMode::reflect:
      if (align_corners) {
        coord = ReflectIndexes<T, IndexT>(coord, 0, 2 * (size - 1));
      } else {
        coord = ReflectIndexes<T, IndexT>(coord, -1, 2 * size - 1);
      }
      coord = ClipIndexes(coord, size);
      break;
    default:
      // zeros padding: no transformation here.
      break;
  }
  return SafeDownGradeToIntRange(coord);
}
6969

70-
template <typename T>
71-
__global__ void GridSampleCudaKernel(const int nthreads,
72-
int n,
73-
int out_c,
74-
int out_h,
75-
int out_w,
76-
int in_h,
77-
int in_w,
70+
template <typename T, typename IndexT>
71+
__global__ void GridSampleCudaKernel(IndexT n,
72+
IndexT out_c,
73+
IndexT out_h,
74+
IndexT out_w,
75+
IndexT in_h,
76+
IndexT in_w,
7877
const T* input,
7978
const T* grid,
8079
T* output,
8180
const Mode mode,
8281
const PaddingMode padding_mode,
8382
bool align_corners) {
84-
int inp_sN = out_c * in_h * in_w;
85-
86-
int inp_sC = in_h * in_w;
87-
int inp_sH = in_w;
88-
int inp_sW = 1;
89-
int grid_sN = out_h * out_w * 2;
90-
int grid_sH = out_w * 2;
91-
int grid_sW = 2;
92-
int grid_sCoor = 1;
93-
int out_sN = out_c * out_h * out_w;
94-
int out_sC = out_h * out_w;
95-
int out_sH = out_w;
96-
int out_sW = 1;
97-
CUDA_KERNEL_LOOP(index, nthreads) {
98-
const int w = index % out_w;
99-
const int h = (index / out_w) % out_h;
100-
const int n = index / (out_h * out_w);
101-
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
83+
IndexT nthreads = n * (out_h * out_w);
84+
IndexT inp_sN = out_c * (in_h * in_w);
85+
IndexT inp_sC = in_h * in_w;
86+
IndexT inp_sH = in_w;
87+
IndexT inp_sW = 1;
88+
IndexT grid_sNHW = 2;
89+
IndexT grid_sCoor = 1;
90+
IndexT out_sN = out_c * (out_h * out_w);
91+
IndexT out_sC = out_h * out_w;
92+
IndexT out_sHW = 1;
93+
CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) {
94+
const IndexT hw = index % (out_h * out_w);
95+
const IndexT n = index / (out_h * out_w);
96+
const IndexT grid_offset = index * grid_sNHW;
10297

10398
T ix = grid[grid_offset];
10499
T iy = grid[grid_offset + grid_sCoor];
105100

106101
ix = ComputePositions(ix, in_w, padding_mode, align_corners);
107102
iy = ComputePositions(iy, in_h, padding_mode, align_corners);
108103
if (mode == Mode::bilinear) {
109-
int ix_nw = static_cast<int>(floor(ix));
110-
int iy_nw = static_cast<int>(floor(iy));
111-
int ix_ne = ix_nw + 1;
112-
int iy_ne = iy_nw;
113-
int ix_sw = ix_nw;
114-
int iy_sw = iy_nw + 1;
115-
int ix_se = ix_nw + 1;
116-
int iy_se = iy_nw + 1;
104+
IndexT ix_nw = floor(ix);
105+
IndexT iy_nw = floor(iy);
106+
IndexT ix_ne = ix_nw + 1;
107+
IndexT iy_ne = iy_nw;
108+
IndexT ix_sw = ix_nw;
109+
IndexT iy_sw = iy_nw + 1;
110+
IndexT ix_se = ix_nw + 1;
111+
IndexT iy_se = iy_nw + 1;
117112

118113
T nw = (ix_se - ix) * (iy_se - iy);
119114
T ne = (ix - ix_sw) * (iy_sw - iy);
120115
T sw = (ix_ne - ix) * (iy - iy_ne);
121116
T se = (ix - ix_nw) * (iy - iy_nw);
122117

123-
auto inp_offset_NC = n * inp_sN;
118+
IndexT inp_offset_NC = n * inp_sN;
119+
T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW);
124120

125-
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
126-
for (int c = 0; c < out_c;
121+
for (IndexT c = 0; c < out_c;
127122
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
128-
*out_ptr_NCHW = static_cast<T>(0);
123+
T value{0};
129124
if (InBounds(iy_nw, ix_nw, in_h, in_w)) {
130-
*out_ptr_NCHW +=
131-
input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
125+
value += input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
132126
}
133127
if (InBounds(iy_ne, ix_ne, in_h, in_w)) {
134-
*out_ptr_NCHW +=
135-
input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
128+
value += input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
136129
}
137130
if (InBounds(iy_sw, ix_sw, in_h, in_w)) {
138-
*out_ptr_NCHW +=
139-
input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
131+
value += input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
140132
}
141133
if (InBounds(iy_se, ix_se, in_h, in_w)) {
142-
*out_ptr_NCHW +=
143-
input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
134+
value += input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
144135
}
136+
*out_ptr_NCHW = value;
145137
}
146138
} else if (mode == Mode::nearest) {
147-
int ix_nearest = static_cast<int>(std::nearbyint(ix));
148-
int iy_nearest = static_cast<int>(std::nearbyint(iy));
149-
auto inp_offset_NC = n * inp_sN;
150-
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
151-
for (int c = 0; c < out_c;
139+
IndexT ix_nearest = std::nearbyint(ix);
140+
IndexT iy_nearest = std::nearbyint(iy);
141+
IndexT inp_offset_NC = n * inp_sN;
142+
T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW);
143+
for (IndexT c = 0; c < out_c;
152144
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
153145
if (InBounds(iy_nearest, ix_nearest, in_h, in_w)) {
154146
*out_ptr_NCHW =
@@ -349,38 +341,56 @@ void GridSampleKernel(const Context& dev_ctx,
349341
}
350342

351343
if (x.dims().size() == 4) {
352-
const int n = grid.dims()[0];
353-
const int out_h = grid.dims()[1];
354-
const int out_w = grid.dims()[2];
355-
const int c = x.dims()[1];
356-
const int in_h = x.dims()[2];
357-
const int in_w = x.dims()[3];
344+
const int64_t n = grid.dims()[0];
345+
const int64_t out_h = grid.dims()[1];
346+
const int64_t out_w = grid.dims()[2];
347+
const int64_t c = x.dims()[1];
348+
const int64_t in_h = x.dims()[2];
349+
const int64_t in_w = x.dims()[3];
358350
VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h
359351
<< "; out_w: " << out_w;
360352

361353
auto* output_data = dev_ctx.template Alloc<T>(out);
362354
VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1] << "; "
363355
<< out->dims()[2] << "; " << out->dims()[3];
364356

365-
int count = static_cast<int>(n * out_h * out_w);
357+
int64_t count = n * out_h * out_w;
366358
auto cu_stream = dev_ctx.stream();
367359
backends::gpu::GpuLaunchConfig config =
368360
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
369-
GridSampleCudaKernel<T>
370-
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
371-
count,
372-
n,
373-
c,
374-
out_h,
375-
out_w,
376-
in_h,
377-
in_w,
378-
x.data<T>(),
379-
grid.data<T>(),
380-
output_data,
381-
enum_mode,
382-
enum_padding_mode,
383-
align_corners);
361+
if (x.numel() <= std::numeric_limits<int>::max() &&
362+
grid.numel() <= std::numeric_limits<int>::max() &&
363+
out->numel() <= std::numeric_limits<int>::max()) {
364+
GridSampleCudaKernel<T, int>
365+
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
366+
n,
367+
c,
368+
out_h,
369+
out_w,
370+
in_h,
371+
in_w,
372+
x.data<T>(),
373+
grid.data<T>(),
374+
output_data,
375+
enum_mode,
376+
enum_padding_mode,
377+
align_corners);
378+
} else {
379+
GridSampleCudaKernel<T, int64_t>
380+
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
381+
n,
382+
c,
383+
out_h,
384+
out_w,
385+
in_h,
386+
in_w,
387+
x.data<T>(),
388+
grid.data<T>(),
389+
output_data,
390+
enum_mode,
391+
enum_padding_mode,
392+
align_corners);
393+
}
384394
} else {
385395
const int n = grid.dims()[0];
386396
const int out_d = grid.dims()[1];

0 commit comments

Comments
 (0)