@@ -221,7 +221,7 @@ bool LaunchStridedCopyCaseOneKernel(
221
221
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& output_stride,
222
222
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& dims,
223
223
int rank,
224
- int numel) {
224
+ int64_t numel) {
225
225
dim3 grid (1 , 1 , 1 ), block (1 , 1 , 1 );
226
226
phi::Array<int64_t , 6 > cur_dims;
227
227
block.x = 512 ;
@@ -398,7 +398,7 @@ void LaunchStridedCopyDefaultKernel(
398
398
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& output_stride,
399
399
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& dims,
400
400
int rank,
401
- int numel) {
401
+ int64_t numel) {
402
402
int64_t block = 512 ;
403
403
int64_t grid = (numel + block - 1 ) / block;
404
404
@@ -648,7 +648,7 @@ bool LaunchStrided2ContiguousCaseOneKernel(
648
648
T* output_data,
649
649
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& dims,
650
650
int rank,
651
- int numel) {
651
+ int64_t numel) {
652
652
dim3 grid (1 , 1 , 1 ), block (1 , 1 , 1 );
653
653
phi::Array<int64_t , 6 > cur_dims;
654
654
block.x = 512 ;
@@ -803,7 +803,7 @@ void LaunchStrided2ContiguousDefaultKernel(
803
803
T* output_data,
804
804
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& dims,
805
805
int rank,
806
- int numel) {
806
+ int64_t numel) {
807
807
int64_t block = 512 ;
808
808
int64_t grid = (numel + block - 1 ) / block;
809
809
@@ -1054,7 +1054,7 @@ bool LaunchContiguous2StridedCaseOneKernel(
1054
1054
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& output_stride,
1055
1055
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& dims,
1056
1056
int rank,
1057
- int numel) {
1057
+ int64_t numel) {
1058
1058
dim3 grid (1 , 1 , 1 ), block (1 , 1 , 1 );
1059
1059
phi::Array<int64_t , 6 > cur_dims;
1060
1060
block.x = 512 ;
@@ -1209,7 +1209,7 @@ void LaunchContiguous2StridedDefaultKernel(
1209
1209
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& output_stride,
1210
1210
const phi::Array<int64_t , phi::DDim::kMaxRank + 1 >& dims,
1211
1211
int rank,
1212
- int numel) {
1212
+ int64_t numel) {
1213
1213
int64_t block = 512 ;
1214
1214
int64_t grid = (numel + block - 1 ) / block;
1215
1215
0 commit comments