@@ -152,45 +152,45 @@ __global__ void GridSampleCudaKernel(IndexT n,
152
152
}
153
153
}
154
154
155
- template <typename T>
156
- __global__ void GridSample3DCudaKernel (const int nthreads,
157
- int out_c,
158
- int out_d,
159
- int out_h,
160
- int out_w,
161
- int in_d,
162
- int in_h,
163
- int in_w,
155
+ template <typename T, typename IndexT >
156
+ __global__ void GridSample3DCudaKernel (const IndexT nthreads,
157
+ IndexT out_c,
158
+ IndexT out_d,
159
+ IndexT out_h,
160
+ IndexT out_w,
161
+ IndexT in_d,
162
+ IndexT in_h,
163
+ IndexT in_w,
164
164
const T* input,
165
165
const T* grid,
166
166
T* output,
167
167
const Mode interpolation_mode,
168
168
const PaddingMode padding_mode,
169
169
bool align_corners) {
170
- int inp_sW = 1 ;
171
- int inp_sH = in_w;
172
- int inp_sD = in_h * in_w;
173
- int inp_sC = in_d * inp_sD;
174
- int inp_sN = out_c * inp_sC;
175
-
176
- int grid_sCoor = 1 ;
177
- int grid_sW = 3 ;
178
- int grid_sH = out_w * grid_sW;
179
- int grid_sD = out_h * grid_sH;
180
- int grid_sN = out_d * grid_sD;
181
-
182
- int out_sW = 1 ;
183
- int out_sH = out_w;
184
- int out_sD = out_h * out_w;
185
- int out_sC = out_d * out_sD;
186
- int out_sN = out_c * out_sC;
187
-
188
- CUDA_KERNEL_LOOP_TYPE (index, nthreads, int ) {
189
- const int w = index % out_w;
190
- const int h = (index / out_w) % out_h;
191
- const int d = (index / (out_h * out_w)) % out_d;
192
- const int n = index / (out_d * out_h * out_w);
193
- const int grid_offset =
170
+ IndexT inp_sW = 1 ;
171
+ IndexT inp_sH = in_w;
172
+ IndexT inp_sD = in_h * in_w;
173
+ IndexT inp_sC = in_d * inp_sD;
174
+ IndexT inp_sN = out_c * inp_sC;
175
+
176
+ IndexT grid_sCoor = 1 ;
177
+ IndexT grid_sW = 3 ;
178
+ IndexT grid_sH = out_w * grid_sW;
179
+ IndexT grid_sD = out_h * grid_sH;
180
+ IndexT grid_sN = out_d * grid_sD;
181
+
182
+ IndexT out_sW = 1 ;
183
+ IndexT out_sH = out_w;
184
+ IndexT out_sD = out_h * out_w;
185
+ IndexT out_sC = out_d * out_sD;
186
+ IndexT out_sN = out_c * out_sC;
187
+
188
+ CUDA_KERNEL_LOOP_TYPE (index, nthreads, IndexT ) {
189
+ const IndexT w = index % out_w;
190
+ const IndexT h = (index / out_w) % out_h;
191
+ const IndexT d = (index / (out_h * out_w)) % out_d;
192
+ const IndexT n = index / (out_d * out_h * out_w);
193
+ const IndexT grid_offset =
194
194
n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
195
195
// get the corresponding input x, y, z coordinates from grid
196
196
T ix = grid[grid_offset];
@@ -203,37 +203,37 @@ __global__ void GridSample3DCudaKernel(const int nthreads,
203
203
// get corner pixel values from (x, y, z)
204
204
// for 4d, we used north-east-south-west
205
205
// for 5d, we add top-bottom
206
- int ix_tnw = static_cast <int >(std::floor (ix));
207
- int iy_tnw = static_cast <int >(std::floor (iy));
208
- int iz_tnw = static_cast <int >(std::floor (iz));
206
+ IndexT ix_tnw = static_cast <IndexT >(std::floor (ix));
207
+ IndexT iy_tnw = static_cast <IndexT >(std::floor (iy));
208
+ IndexT iz_tnw = static_cast <IndexT >(std::floor (iz));
209
209
210
- int ix_tne = ix_tnw + 1 ;
211
- int iy_tne = iy_tnw;
212
- int iz_tne = iz_tnw;
210
+ IndexT ix_tne = ix_tnw + 1 ;
211
+ IndexT iy_tne = iy_tnw;
212
+ IndexT iz_tne = iz_tnw;
213
213
214
- int ix_tsw = ix_tnw;
215
- int iy_tsw = iy_tnw + 1 ;
216
- int iz_tsw = iz_tnw;
214
+ IndexT ix_tsw = ix_tnw;
215
+ IndexT iy_tsw = iy_tnw + 1 ;
216
+ IndexT iz_tsw = iz_tnw;
217
217
218
- int ix_tse = ix_tnw + 1 ;
219
- int iy_tse = iy_tnw + 1 ;
220
- int iz_tse = iz_tnw;
218
+ IndexT ix_tse = ix_tnw + 1 ;
219
+ IndexT iy_tse = iy_tnw + 1 ;
220
+ IndexT iz_tse = iz_tnw;
221
221
222
- int ix_bnw = ix_tnw;
223
- int iy_bnw = iy_tnw;
224
- int iz_bnw = iz_tnw + 1 ;
222
+ IndexT ix_bnw = ix_tnw;
223
+ IndexT iy_bnw = iy_tnw;
224
+ IndexT iz_bnw = iz_tnw + 1 ;
225
225
226
- int ix_bne = ix_tnw + 1 ;
227
- int iy_bne = iy_tnw;
228
- int iz_bne = iz_tnw + 1 ;
226
+ IndexT ix_bne = ix_tnw + 1 ;
227
+ IndexT iy_bne = iy_tnw;
228
+ IndexT iz_bne = iz_tnw + 1 ;
229
229
230
- int ix_bsw = ix_tnw;
231
- int iy_bsw = iy_tnw + 1 ;
232
- int iz_bsw = iz_tnw + 1 ;
230
+ IndexT ix_bsw = ix_tnw;
231
+ IndexT iy_bsw = iy_tnw + 1 ;
232
+ IndexT iz_bsw = iz_tnw + 1 ;
233
233
234
- int ix_bse = ix_tnw + 1 ;
235
- int iy_bse = iy_tnw + 1 ;
236
- int iz_bse = iz_tnw + 1 ;
234
+ IndexT ix_bse = ix_tnw + 1 ;
235
+ IndexT iy_bse = iy_tnw + 1 ;
236
+ IndexT iz_bse = iz_tnw + 1 ;
237
237
238
238
// get surfaces to each neighbor:
239
239
T tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
@@ -245,10 +245,10 @@ __global__ void GridSample3DCudaKernel(const int nthreads,
245
245
T bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
246
246
T bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
247
247
248
- auto inp_ptr_NC = input + n * inp_sN;
249
- auto out_ptr_NCDHW =
250
- output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
251
- for (int c = 0 ; c < out_c;
248
+ const T* inp_ptr_NC = input + n * inp_sN;
249
+ T* out_ptr_NCDHW =
250
+ output + ( n * out_sN + d * out_sD + h * out_sH + w * out_sW) ;
251
+ for (IndexT c = 0 ; c < out_c;
252
252
++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
253
253
*out_ptr_NCDHW = static_cast <T>(0 );
254
254
if (InBounds3D (iz_tnw, iy_tnw, ix_tnw, in_d, in_h, in_w)) {
@@ -293,15 +293,15 @@ __global__ void GridSample3DCudaKernel(const int nthreads,
293
293
}
294
294
}
295
295
} else if (interpolation_mode == Mode::nearest) {
296
- int ix_nearest = static_cast <int >(std::round (ix));
297
- int iy_nearest = static_cast <int >(std::round (iy));
298
- int iz_nearest = static_cast <int >(std::round (iz));
296
+ IndexT ix_nearest = static_cast <IndexT >(std::round (ix));
297
+ IndexT iy_nearest = static_cast <IndexT >(std::round (iy));
298
+ IndexT iz_nearest = static_cast <IndexT >(std::round (iz));
299
299
300
300
// assign nearest neighbor pixel value to output pixel
301
- auto inp_ptr_NC = input + n * inp_sN;
302
- auto out_ptr_NCDHW =
303
- output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
304
- for (int c = 0 ; c < out_c;
301
+ const T* inp_ptr_NC = input + n * inp_sN;
302
+ T* out_ptr_NCDHW =
303
+ output + ( n * out_sN + d * out_sD + h * out_sH + w * out_sW) ;
304
+ for (IndexT c = 0 ; c < out_c;
305
305
++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
306
306
if (InBounds3D (iz_nearest, iy_nearest, ix_nearest, in_d, in_h, in_w)) {
307
307
*out_ptr_NCDHW =
@@ -343,6 +343,10 @@ void GridSampleKernel(const Context& dev_ctx,
343
343
enum_mode = Mode::bilinear;
344
344
}
345
345
346
+ bool use_int32_index = x.numel () <= std::numeric_limits<int >::max () &&
347
+ grid.numel () <= std::numeric_limits<int >::max () &&
348
+ out->numel () <= std::numeric_limits<int >::max ();
349
+
346
350
if (x.dims ().size () == 4 ) {
347
351
const int64_t n = grid.dims ()[0 ];
348
352
const int64_t out_h = grid.dims ()[1 ];
@@ -361,46 +365,36 @@ void GridSampleKernel(const Context& dev_ctx,
361
365
auto cu_stream = dev_ctx.stream ();
362
366
backends::gpu::GpuLaunchConfig config =
363
367
backends::gpu::GetGpuLaunchConfig1D (dev_ctx, count);
364
- if (x.numel () <= std::numeric_limits<int >::max () &&
365
- grid.numel () <= std::numeric_limits<int >::max () &&
366
- out->numel () <= std::numeric_limits<int >::max ()) {
367
- GridSampleCudaKernel<T, int >
368
- <<<config.block_per_grid, config.thread_per_block, 0 , cu_stream>>> (
369
- n,
370
- c,
371
- out_h * out_w,
372
- in_h,
373
- in_w,
374
- x.data <T>(),
375
- grid.data <T>(),
376
- output_data,
377
- enum_mode,
378
- enum_padding_mode,
379
- align_corners);
368
+
369
+ #define LAUNCH_KERNEL (INDEX_TYPE ) \
370
+ GridSampleCudaKernel<T, INDEX_TYPE> \
371
+ <<<config.block_per_grid, config.thread_per_block, 0 , cu_stream>>> ( \
372
+ n, \
373
+ c, \
374
+ out_h * out_w, \
375
+ in_h, \
376
+ in_w, \
377
+ x.data <T>(), \
378
+ grid.data <T>(), \
379
+ output_data, \
380
+ enum_mode, \
381
+ enum_padding_mode, \
382
+ align_corners)
383
+ if (use_int32_index) {
384
+ LAUNCH_KERNEL (int );
380
385
} else {
381
- GridSampleCudaKernel<T, int64_t >
382
- <<<config.block_per_grid, config.thread_per_block, 0 , cu_stream>>> (
383
- n,
384
- c,
385
- out_h * out_w,
386
- in_h,
387
- in_w,
388
- x.data <T>(),
389
- grid.data <T>(),
390
- output_data,
391
- enum_mode,
392
- enum_padding_mode,
393
- align_corners);
386
+ LAUNCH_KERNEL (int64_t );
394
387
}
388
+ #undef LAUNCH_KERNEL
395
389
} else {
396
- const int n = grid.dims ()[0 ];
397
- const int out_d = grid.dims ()[1 ];
398
- const int out_h = grid.dims ()[2 ];
399
- const int out_w = grid.dims ()[3 ];
400
- const int c = x.dims ()[1 ];
401
- const int in_d = x.dims ()[2 ];
402
- const int in_h = x.dims ()[3 ];
403
- const int in_w = x.dims ()[4 ];
390
+ const int64_t n = grid.dims ()[0 ];
391
+ const int64_t out_d = grid.dims ()[1 ];
392
+ const int64_t out_h = grid.dims ()[2 ];
393
+ const int64_t out_w = grid.dims ()[3 ];
394
+ const int64_t c = x.dims ()[1 ];
395
+ const int64_t in_d = x.dims ()[2 ];
396
+ const int64_t in_h = x.dims ()[3 ];
397
+ const int64_t in_w = x.dims ()[4 ];
404
398
405
399
VLOG (3 ) << " n: " << n << " ; c: " << c << " ; out_d: " << out_d
406
400
<< " ; out_h: " << out_h << " ; out_w: " << out_w;
@@ -410,26 +404,34 @@ void GridSampleKernel(const Context& dev_ctx,
410
404
<< out->dims ()[2 ] << " ; " << out->dims ()[3 ] << " ; "
411
405
<< out->dims ()[4 ];
412
406
413
- int count = static_cast < int >( n * out_d * out_h * out_w) ;
407
+ int64_t count = n * out_d * out_h * out_w;
414
408
auto cu_stream = dev_ctx.stream ();
415
409
backends::gpu::GpuLaunchConfig config =
416
410
backends::gpu::GetGpuLaunchConfig1D (dev_ctx, count);
417
- GridSample3DCudaKernel<T>
418
- <<<config.block_per_grid, config.thread_per_block, 0 , cu_stream>>> (
419
- count,
420
- c,
421
- out_d,
422
- out_h,
423
- out_w,
424
- in_d,
425
- in_h,
426
- in_w,
427
- x.data <T>(),
428
- grid.data <T>(),
429
- output_data,
430
- enum_mode,
431
- enum_padding_mode,
432
- align_corners);
411
+
412
+ #define LAUNCH_KERNEL (INDEX_TYPE ) \
413
+ GridSample3DCudaKernel<T, INDEX_TYPE> \
414
+ <<<config.block_per_grid, config.thread_per_block, 0 , cu_stream>>> ( \
415
+ count, \
416
+ c, \
417
+ out_d, \
418
+ out_h, \
419
+ out_w, \
420
+ in_d, \
421
+ in_h, \
422
+ in_w, \
423
+ x.data <T>(), \
424
+ grid.data <T>(), \
425
+ output_data, \
426
+ enum_mode, \
427
+ enum_padding_mode, \
428
+ align_corners)
429
+ if (use_int32_index) {
430
+ LAUNCH_KERNEL (int );
431
+ } else {
432
+ LAUNCH_KERNEL (int64_t );
433
+ }
434
+ #undef LAUNCH_KERNEL
433
435
}
434
436
}
435
437
0 commit comments