Skip to content

Commit 752eb6e

Browse files
Enigmatisms authored and wanghuancoder committed
[PHI] Preliminary fix for elementwise broadcast int32 shape overflow (PaddlePaddle#72584)
1 parent 04d7eb9 commit 752eb6e

File tree

9 files changed

+75
-69
lines changed

9 files changed

+75
-69
lines changed

paddle/fluid/operators/elementwise/elementwise_op.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ class ElementwiseOp : public framework::OperatorWithKernel {
106106
axis));
107107
axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
108108
: axis);
109-
std::vector<int> x_dims_array(max_dim);
110-
std::vector<int> y_dims_array(max_dim);
111-
std::vector<int> out_dims_array(max_dim);
109+
std::vector<int64_t> x_dims_array(max_dim);
110+
std::vector<int64_t> y_dims_array(max_dim);
111+
std::vector<int64_t> out_dims_array(max_dim);
112112
#ifdef PADDLE_WITH_DNNL
113113
// Broadcasting of dims has to be done on Paddle shapes (NHWC)
114114
// if model is using NHWC and any of shapes in at least 3D
@@ -120,8 +120,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
120120
if (should_rotate) {
121121
// Pick bigger shape and rotate this one
122122
bool x_over_y = (x_dims.size() > y_dims.size());
123-
auto vdims = x_over_y ? common::vectorize<int>(x_dims)
124-
: common::vectorize<int>(y_dims);
123+
auto vdims = x_over_y ? common::vectorize<int64_t>(x_dims)
124+
: common::vectorize<int64_t>(y_dims);
125125
std::rotate(vdims.begin() + 1, vdims.begin() + 2, vdims.end());
126126
if (x_over_y) {
127127
x_dims = common::make_ddim(vdims);

paddle/phi/infermeta/binary.cc

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -489,9 +489,9 @@ void CompareRawInferMeta(const MetaTensor& x,
489489
} else {
490490
int max_dim = std::max(dim_x.size(), dim_y.size());
491491
int axis = std::abs(dim_x.size() - dim_y.size());
492-
std::vector<int> x_dims_array(max_dim);
493-
std::vector<int> y_dims_array(max_dim);
494-
std::vector<int> out_dims_array(max_dim);
492+
std::vector<int64_t> x_dims_array(max_dim);
493+
std::vector<int64_t> y_dims_array(max_dim);
494+
std::vector<int64_t> out_dims_array(max_dim);
495495
funcs::GetBroadcastDimsArrays(dim_x,
496496
dim_y,
497497
x_dims_array.data(),
@@ -543,9 +543,9 @@ void ComplexInferMeta(const MetaTensor& x,
543543

544544
// start align axis
545545
int axis = std::abs(x_dims.size() - y_dims.size());
546-
std::vector<int> x_dims_array(max_dim);
547-
std::vector<int> y_dims_array(max_dim);
548-
std::vector<int> out_dims_array(max_dim);
546+
std::vector<int64_t> x_dims_array(max_dim);
547+
std::vector<int64_t> y_dims_array(max_dim);
548+
std::vector<int64_t> out_dims_array(max_dim);
549549
phi::funcs::GetBroadcastDimsArrays(x_dims,
550550
y_dims,
551551
x_dims_array.data(),
@@ -1690,9 +1690,9 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
16901690
axis));
16911691
axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
16921692
: axis);
1693-
std::vector<int> x_dims_array(max_dim);
1694-
std::vector<int> y_dims_array(max_dim);
1695-
std::vector<int> out_dims_array(max_dim);
1693+
std::vector<int64_t> x_dims_array(max_dim);
1694+
std::vector<int64_t> y_dims_array(max_dim);
1695+
std::vector<int64_t> out_dims_array(max_dim);
16961696

16971697
#ifdef PADDLE_WITH_DNNL
16981698
bool should_rotate =
@@ -1703,8 +1703,8 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
17031703
if (should_rotate) {
17041704
// Pick bigger shape and rotate this one
17051705
bool x_over_y = (common::product(x_dims) > common::product(y_dims));
1706-
auto vdims = x_over_y ? common::vectorize<int>(x_dims)
1707-
: common::vectorize<int>(y_dims);
1706+
auto vdims = x_over_y ? common::vectorize<int64_t>(x_dims)
1707+
: common::vectorize<int64_t>(y_dims);
17081708
std::rotate(vdims.begin() + 1, vdims.begin() + 2, vdims.end());
17091709
if (x_over_y) {
17101710
x_dims = common::make_ddim(vdims);
@@ -3141,8 +3141,8 @@ void MatrixRankTolInferMeta(const MetaTensor& x,
31413141
"The dims of input must be greater than 2"));
31423142

31433143
if (hermitian) {
3144-
int rows = static_cast<int>(dim_x[dim_x.size() - 2]);
3145-
int cols = static_cast<int>(dim_x[dim_x.size() - 1]);
3144+
int64_t rows = static_cast<int64_t>(dim_x[dim_x.size() - 2]);
3145+
int64_t cols = static_cast<int64_t>(dim_x[dim_x.size() - 1]);
31463146
PADDLE_ENFORCE_EQ(rows,
31473147
cols,
31483148
common::errors::InvalidArgument(
@@ -3155,9 +3155,9 @@ void MatrixRankTolInferMeta(const MetaTensor& x,
31553155
} else {
31563156
int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
31573157
int axis = std::abs(dim_x_batch.size() - dim_tol.size());
3158-
std::vector<int> x_batch_dims_array(max_dim);
3159-
std::vector<int> tol_dims_array(max_dim);
3160-
std::vector<int> out_dims_array(max_dim);
3158+
std::vector<int64_t> x_batch_dims_array(max_dim);
3159+
std::vector<int64_t> tol_dims_array(max_dim);
3160+
std::vector<int64_t> out_dims_array(max_dim);
31613161
phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
31623162
dim_tol,
31633163
x_batch_dims_array.data(),

paddle/phi/infermeta/multiary.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5153,16 +5153,16 @@ void SendUERecvInferMeta(const MetaTensor& x,
51535153

51545154
// Infer out's shape according to x and e(need broadcasting condition)
51555155
out->set_dtype(x.dtype());
5156-
auto x_dims1 = common::vectorize<int>(x_dims);
5157-
auto y_dims1 = common::vectorize<int>(y_dims);
5158-
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
5159-
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
5156+
auto x_dims1 = common::vectorize<int64_t>(x_dims);
5157+
auto y_dims1 = common::vectorize<int64_t>(y_dims);
5158+
std::vector<int64_t> x_dims2(x_dims1.begin() + 1, x_dims1.end());
5159+
std::vector<int64_t> y_dims2(y_dims1.begin() + 1, y_dims1.end());
51605160

51615161
int max_dim = static_cast<int>(std::max(x_dims2.size(), y_dims2.size()));
51625162
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
5163-
std::vector<int> x_dims_array(max_dim);
5164-
std::vector<int> y_dims_array(max_dim);
5165-
std::vector<int> out_dims_array(max_dim);
5163+
std::vector<int64_t> x_dims_array(max_dim);
5164+
std::vector<int64_t> y_dims_array(max_dim);
5165+
std::vector<int64_t> out_dims_array(max_dim);
51665166
// Only need to broadcast dimensions other than the 0th dimension.
51675167
phi::funcs::GetBroadcastDimsArrays(common::make_ddim(x_dims2),
51685168
common::make_ddim(y_dims2),
@@ -5224,15 +5224,15 @@ void SendUVInferMeta(const MetaTensor& x,
52245224
out->set_dtype(x.dtype());
52255225
auto x_dims = x.dims();
52265226
auto y_dims = y.dims();
5227-
auto x_dims1 = common::vectorize<int>(x_dims);
5228-
auto y_dims1 = common::vectorize<int>(y_dims);
5229-
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
5230-
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
5227+
auto x_dims1 = common::vectorize<int64_t>(x_dims);
5228+
auto y_dims1 = common::vectorize<int64_t>(y_dims);
5229+
std::vector<int64_t> x_dims2(x_dims1.begin() + 1, x_dims1.end());
5230+
std::vector<int64_t> y_dims2(y_dims1.begin() + 1, y_dims1.end());
52315231
int max_dim = static_cast<int>(std::max(x_dims2.size(), y_dims2.size()));
52325232
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
5233-
std::vector<int> x_dims_array(max_dim);
5234-
std::vector<int> y_dims_array(max_dim);
5235-
std::vector<int> out_dims_array(max_dim);
5233+
std::vector<int64_t> x_dims_array(max_dim);
5234+
std::vector<int64_t> y_dims_array(max_dim);
5235+
std::vector<int64_t> out_dims_array(max_dim);
52365236
// Only need to broadcast dimensions other than the 0th dimension.
52375237
phi::funcs::GetBroadcastDimsArrays(common::make_ddim(x_dims2),
52385238
common::make_ddim(y_dims2),

paddle/phi/kernels/funcs/common_infer_shape_functions.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ inline phi::DDim BroadcastTwoDims(const phi::DDim &x_dims,
2121
int axis = -1) {
2222
int max_dim = std::max(x_dims.size(), y_dims.size());
2323
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
24-
std::vector<int> x_dims_array(max_dim);
25-
std::vector<int> y_dims_array(max_dim);
26-
std::vector<int> out_dims_array(max_dim);
24+
std::vector<int64_t> x_dims_array(max_dim);
25+
std::vector<int64_t> y_dims_array(max_dim);
26+
std::vector<int64_t> out_dims_array(max_dim);
2727
GetBroadcastDimsArrays(x_dims,
2828
y_dims,
2929
x_dims_array.data(),

paddle/phi/kernels/funcs/common_shape.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ inline void SetXShape(const DenseTensor &x, DenseTensor *xshape) {
3232
xshape->ResetLoD(x.meta().legacy_lod);
3333
}
3434

35+
template <typename T>
3536
inline void GetBroadcastDimsArrays(const DDim &x_dims,
3637
const DDim &y_dims,
37-
int *x_dims_array,
38-
int *y_dims_array,
39-
int *out_dims_array,
38+
T *x_dims_array,
39+
T *y_dims_array,
40+
T *out_dims_array,
4041
const int max_dim,
4142
const int axis) {
4243
PADDLE_ENFORCE_GE(

paddle/phi/kernels/funcs/elementwise_base.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,14 @@ template <typename Functor, typename T, typename OutType = T>
274274
void CommonForwardBroadcastCPU(const DenseTensor &x,
275275
const DenseTensor &y,
276276
DenseTensor *z,
277-
int *x_dims_array,
278-
int *y_dims_array,
279-
int *out_dims_array,
277+
int64_t *x_dims_array,
278+
int64_t *y_dims_array,
279+
int64_t *out_dims_array,
280280
int max_dim,
281281
const CPUContext &ctx,
282282
Functor func,
283283
const bool is_xsize_larger = true) {
284-
std::vector<int> index_array(max_dim, 0);
284+
std::vector<int64_t> index_array(max_dim, 0);
285285
const T *x_data = x.data<T>();
286286
const T *y_data = y.data<T>();
287287
PADDLE_ENFORCE_NOT_NULL(
@@ -290,8 +290,10 @@ void CommonForwardBroadcastCPU(const DenseTensor &x,
290290
y_data, errors::InvalidArgument("The input Y should not be empty."));
291291
OutType *out_data = ctx.Alloc<OutType>(z);
292292

293-
const int out_size = std::accumulate(
294-
out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
293+
const int64_t out_size = std::accumulate(out_dims_array,
294+
out_dims_array + max_dim,
295+
1ll,
296+
std::multiplies<int64_t>());
295297
int x_index, y_index;
296298
for (int out_index = 0; out_index < out_size; ++out_index) {
297299
x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
@@ -331,9 +333,9 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
331333
"Axis should be less than or equal to %d, but received axis is %d.",
332334
max_dim,
333335
axis));
334-
std::vector<int> x_dims_array(max_dim);
335-
std::vector<int> y_dims_array(max_dim);
336-
std::vector<int> out_dims_array(max_dim);
336+
std::vector<int64_t> x_dims_array(max_dim);
337+
std::vector<int64_t> y_dims_array(max_dim);
338+
std::vector<int64_t> out_dims_array(max_dim);
337339
GetBroadcastDimsArrays(x_dims,
338340
y_dims,
339341
x_dims_array.data(),

paddle/phi/kernels/funcs/elementwise_utils.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,11 @@ inline DDim TrimTrailingSingularDims(const DDim &dims) {
9494
return actual_dims;
9595
}
9696

97-
inline int GetElementwiseIndex(const int *x_dims_array,
98-
const int max_dim,
99-
const int *index_array) {
100-
int index_ = 0;
97+
template <typename ShapeT = int>
98+
inline ShapeT GetElementwiseIndex(const ShapeT *x_dims_array,
99+
const int max_dim,
100+
const ShapeT *index_array) {
101+
ShapeT index_ = 0;
101102
for (int i = 0; i < max_dim; i++) {
102103
if (x_dims_array[i] > 1) {
103104
index_ = index_ * x_dims_array[i] + index_array[i];
@@ -106,9 +107,10 @@ inline int GetElementwiseIndex(const int *x_dims_array,
106107
return index_;
107108
}
108109

109-
inline void UpdateElementwiseIndexArray(const int *out_dims_array,
110+
template <typename ShapeT = int>
111+
inline void UpdateElementwiseIndexArray(const ShapeT *out_dims_array,
110112
const int max_dim,
111-
int *index_array) {
113+
ShapeT *index_array) {
112114
for (int i = max_dim - 1; i >= 0; --i) {
113115
++index_array[i];
114116
if (index_array[i] >= out_dims_array[i]) {

paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace fusion {
2525
static phi::DDim BroadCastInferShape(const DDim x_dims,
2626
const DDim y_dims,
2727
int axis) {
28-
std::vector<int> out_dims_array(x_dims.size(), -1);
28+
std::vector<int64_t> out_dims_array(x_dims.size(), -1);
2929
if (x_dims != y_dims) {
3030
int max_dim = std::max(x_dims.size(), y_dims.size());
3131
if (x_dims.size() == y_dims.size()) {
@@ -49,8 +49,8 @@ static phi::DDim BroadCastInferShape(const DDim x_dims,
4949
axis));
5050
axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
5151
: axis);
52-
std::vector<int> x_dims_array(max_dim);
53-
std::vector<int> y_dims_array(max_dim);
52+
std::vector<int64_t> x_dims_array(max_dim);
53+
std::vector<int64_t> y_dims_array(max_dim);
5454
out_dims_array.resize(max_dim);
5555
phi::funcs::GetBroadcastDimsArrays(x_dims,
5656
y_dims,

paddle/phi/kernels/impl/graph_message_passing_impl.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,19 @@ inline BroadCastInfo CalcBCastInfo(const phi::DDim& l_dims,
8787
return binfo;
8888
}
8989

90-
inline std::vector<int> InferBroadcastShape(const phi::DDim& x_dims,
91-
const phi::DDim& e_dims,
92-
const std::string& type = "x") {
93-
auto x_dims1 = common::vectorize<int>(x_dims);
94-
auto e_dims1 = common::vectorize<int>(e_dims);
95-
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
96-
std::vector<int> e_dims2(e_dims1.begin() + 1, e_dims1.end());
90+
template <typename ShapeT = int64_t>
91+
inline std::vector<ShapeT> InferBroadcastShape(const phi::DDim& x_dims,
92+
const phi::DDim& e_dims,
93+
const std::string& type = "x") {
94+
auto x_dims1 = common::vectorize<ShapeT>(x_dims);
95+
auto e_dims1 = common::vectorize<ShapeT>(e_dims);
96+
std::vector<ShapeT> x_dims2(x_dims1.begin() + 1, x_dims1.end());
97+
std::vector<ShapeT> e_dims2(e_dims1.begin() + 1, e_dims1.end());
9798
int max_dim = std::max(x_dims2.size(), e_dims2.size());
9899
int axis = std::abs(static_cast<int>(x_dims2.size() - e_dims2.size()));
99-
std::vector<int> x_dims_array(max_dim);
100-
std::vector<int> e_dims_array(max_dim);
101-
std::vector<int> out_dims_array(max_dim);
100+
std::vector<ShapeT> x_dims_array(max_dim);
101+
std::vector<ShapeT> e_dims_array(max_dim);
102+
std::vector<ShapeT> out_dims_array(max_dim);
102103
// Only need to broadcast dimensions other than the 0th dimension.
103104
phi::funcs::GetBroadcastDimsArrays(common::make_ddim(x_dims2),
104105
common::make_ddim(e_dims2),

0 commit comments

Comments (0)