Skip to content

Commit 4e331fd

Browse files
Dmovicwanghuancoder
authored andcommitted
Update strided copy kernel (PaddlePaddle#72662)
1 parent 98b4876 commit 4e331fd

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

paddle/phi/kernels/gpu/strided_copy_kernel.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ bool LaunchStridedCopyCaseOneKernel(
221221
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& output_stride,
222222
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
223223
int rank,
224-
int numel) {
224+
int64_t numel) {
225225
dim3 grid(1, 1, 1), block(1, 1, 1);
226226
phi::Array<int64_t, 6> cur_dims;
227227
block.x = 512;
@@ -398,7 +398,7 @@ void LaunchStridedCopyDefaultKernel(
398398
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& output_stride,
399399
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
400400
int rank,
401-
int numel) {
401+
int64_t numel) {
402402
int64_t block = 512;
403403
int64_t grid = (numel + block - 1) / block;
404404

@@ -648,7 +648,7 @@ bool LaunchStrided2ContiguousCaseOneKernel(
648648
T* output_data,
649649
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
650650
int rank,
651-
int numel) {
651+
int64_t numel) {
652652
dim3 grid(1, 1, 1), block(1, 1, 1);
653653
phi::Array<int64_t, 6> cur_dims;
654654
block.x = 512;
@@ -803,7 +803,7 @@ void LaunchStrided2ContiguousDefaultKernel(
803803
T* output_data,
804804
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
805805
int rank,
806-
int numel) {
806+
int64_t numel) {
807807
int64_t block = 512;
808808
int64_t grid = (numel + block - 1) / block;
809809

@@ -1054,7 +1054,7 @@ bool LaunchContiguous2StridedCaseOneKernel(
10541054
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& output_stride,
10551055
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
10561056
int rank,
1057-
int numel) {
1057+
int64_t numel) {
10581058
dim3 grid(1, 1, 1), block(1, 1, 1);
10591059
phi::Array<int64_t, 6> cur_dims;
10601060
block.x = 512;
@@ -1209,7 +1209,7 @@ void LaunchContiguous2StridedDefaultKernel(
12091209
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& output_stride,
12101210
const phi::Array<int64_t, phi::DDim::kMaxRank + 1>& dims,
12111211
int rank,
1212-
int numel) {
1212+
int64_t numel) {
12131213
int64_t block = 512;
12141214
int64_t grid = (numel + block - 1) / block;
12151215

0 commit comments

Comments
 (0)