Skip to content

Commit 86d658f

Browse files
[0-size Tensor Job2 No.21-23] Add 0-size Tensor support for send_u_recv (#73806)
* fix send_uev func 0-size * Fix the issue of automatically throwing exceptions * Fix2 * add test of send_uev_recv func * fix test * fix : add PADDLE_ENFORCE_EQ
1 parent 61b91dc commit 86d658f

17 files changed

+494
-12
lines changed

paddle/phi/infermeta/multiary.cc

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5331,10 +5331,13 @@ void SendUERecvInferMeta(const MetaTensor& x,
53315331
dst_index_dims.size()));
53325332
}
53335333

5334-
PADDLE_ENFORCE_EQ(src_index_dims[0],
5335-
dst_index_dims[0],
5336-
common::errors::InvalidArgument(
5337-
"Src_index and Dst_index should have the same shape."));
5334+
if (src_index_dims[0] != 0) {
5335+
PADDLE_ENFORCE_EQ(
5336+
src_index_dims[0],
5337+
dst_index_dims[0],
5338+
common::errors::InvalidArgument(
5339+
"Src_index and Dst_index should have the same shape."));
5340+
}
53385341

53395342
auto y_dims = y.dims();
53405343
PADDLE_ENFORCE_EQ(
@@ -5416,10 +5419,13 @@ void SendUVInferMeta(const MetaTensor& x,
54165419
dst_index_dims.size()));
54175420
}
54185421

5419-
PADDLE_ENFORCE_EQ(src_index_dims[0],
5420-
dst_index_dims[0],
5421-
common::errors::InvalidArgument(
5422-
"Src_index and Dst_index should have the same shape."));
5422+
if (src_index_dims[0] != 0) {
5423+
PADDLE_ENFORCE_EQ(
5424+
src_index_dims[0],
5425+
dst_index_dims[0],
5426+
common::errors::InvalidArgument(
5427+
"Src_index and Dst_index should have the same shape."));
5428+
}
54235429

54245430
// Infer out's shape according to x and y(need broadcasting condition)
54255431
out->set_dtype(x.dtype());

paddle/phi/infermeta/ternary.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2517,10 +2517,13 @@ void SendURecvInferMeta(const MetaTensor& x,
25172517
dst_index_dims.size()));
25182518
}
25192519

2520-
PADDLE_ENFORCE_EQ(src_index_dims[0],
2521-
dst_index_dims[0],
2522-
common::errors::InvalidArgument(
2523-
"Src_index and Dst_index should have the same shape."));
2520+
if (src_index_dims[0] != 0) {
2521+
PADDLE_ENFORCE_EQ(
2522+
src_index_dims[0],
2523+
dst_index_dims[0],
2524+
common::errors::InvalidArgument(
2525+
"Src_index and Dst_index should have the same shape."));
2526+
}
25242527

25252528
auto dims = x.dims();
25262529
std::vector<int64_t> dims_ = common::vectorize(dims);

paddle/phi/kernels/cpu/send_u_recv_grad_kernel.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "paddle/phi/core/kernel_registry.h"
2121
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
22+
#include "paddle/phi/kernels/full_kernel.h"
2223

2324
namespace phi {
2425

@@ -128,6 +129,14 @@ void SendURecvGradKernel(const Context& dev_ctx,
128129
const std::string& reduce_op,
129130
DenseTensor* x_grad) {
130131
auto index_type = src_index.dtype();
132+
133+
if (out_grad.numel() == 0 || x.numel() == 0 || src_index.numel() == 0 ||
134+
dst_index.numel() == 0) {
135+
phi::Full<T, Context>(
136+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
137+
return;
138+
}
139+
131140
if (index_type == phi::DataType::INT32) {
132141
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(
133142
dev_ctx,

paddle/phi/kernels/cpu/send_u_recv_kernel.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "paddle/phi/backends/cpu/cpu_context.h"
2323
#include "paddle/phi/core/kernel_registry.h"
2424
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
25+
#include "paddle/phi/kernels/full_kernel.h"
2526

2627
namespace phi {
2728

@@ -154,6 +155,28 @@ void SendURecvKernel(const Context& dev_ctx,
154155
DenseTensor* dst_count) {
155156
auto index_type = src_index.dtype();
156157
auto& out_size_data = out_size.GetData();
158+
159+
if (x.numel() == 0 || src_index.numel() == 0 || dst_index.numel() == 0) {
160+
if (out_size_data[0] <= 0) {
161+
out->Resize(x.dims());
162+
} else {
163+
out->Resize(common::make_ddim(out_size_data));
164+
}
165+
if (reduce_op == "MEAN") {
166+
int64_t input_size =
167+
out_size_data[0] <= 0 ? x.dims()[0] : out_size_data[0];
168+
dst_count->Resize({input_size});
169+
}
170+
phi::Full<T, Context>(
171+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
172+
phi::Full<int32_t, Context>(
173+
dev_ctx,
174+
phi::IntArray(common::vectorize(dst_count->dims())),
175+
0,
176+
dst_count);
177+
return;
178+
}
179+
157180
if (index_type == phi::DataType::INT32) {
158181
GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
159182
x,

paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
2424
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
2525
#include "paddle/phi/kernels/empty_kernel.h"
26+
#include "paddle/phi/kernels/full_kernel.h"
2627
#include "paddle/phi/kernels/funcs/math_function.h"
2728
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
2829
#include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -458,6 +459,16 @@ void SendUERecvGradKernel(const Context& dev_ctx,
458459
DenseTensor* x_grad,
459460
DenseTensor* y_grad) {
460461
auto index_type = src_index.dtype();
462+
463+
if (out_grad.numel() == 0 || x.numel() == 0 || y.numel() == 0 ||
464+
src_index.numel() == 0 || dst_index.numel() == 0) {
465+
phi::Full<T, Context>(
466+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
467+
phi::Full<T, Context>(
468+
dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
469+
return;
470+
}
471+
461472
if (index_type == phi::DataType::INT32) {
462473
GraphSendUERecvGradOpKernelLaunchHelper<Context, T, int32_t>(
463474
dev_ctx,

paddle/phi/kernels/cpu/send_ue_recv_kernel.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "paddle/phi/backends/cpu/cpu_context.h"
2323
#include "paddle/phi/core/kernel_registry.h"
2424
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
25+
#include "paddle/phi/kernels/full_kernel.h"
2526
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
2627

2728
namespace phi {
@@ -256,6 +257,30 @@ void SendUERecvKernel(const Context& dev_ctx,
256257
DenseTensor* dst_count) {
257258
auto index_type = src_index.dtype();
258259
auto& out_size_data = out_size.GetData();
260+
261+
if (x.numel() == 0 || y.numel() == 0 || src_index.numel() == 0 ||
262+
dst_index.numel() == 0) {
263+
std::vector<int64_t> dims_ = common::vectorize(out->dims());
264+
if (out_size_data[0] <= 0) {
265+
dims_[0] = x.dims()[0];
266+
} else {
267+
dims_[0] = out_size_data[0];
268+
}
269+
if (reduce_op == "MEAN") {
270+
int64_t input_size =
271+
out_size_data[0] <= 0 ? x.dims()[0] : out_size_data[0];
272+
dst_count->Resize({input_size});
273+
}
274+
out->Resize(common::make_ddim(dims_));
275+
phi::Full<T, Context>(
276+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
277+
phi::Full<int, Context>(dev_ctx,
278+
phi::IntArray(common::vectorize(dst_count->dims())),
279+
0,
280+
dst_count);
281+
return;
282+
}
283+
259284
if (index_type == phi::DataType::INT32) {
260285
GraphSendUERecvOpKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
261286
x,

paddle/phi/kernels/cpu/send_uv_grad_kernel.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "paddle/phi/backends/cpu/cpu_context.h"
1919
#include "paddle/phi/core/kernel_registry.h"
2020
#include "paddle/phi/kernels/empty_kernel.h"
21+
#include "paddle/phi/kernels/full_kernel.h"
2122
#include "paddle/phi/kernels/funcs/math_function.h"
2223
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
2324
#include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -241,6 +242,16 @@ void SendUVGradKernel(const Context& dev_ctx,
241242
DenseTensor* x_grad,
242243
DenseTensor* y_grad) {
243244
auto index_type = src_index.dtype();
245+
246+
if (out_grad.numel() == 0 || x.numel() == 0 || y.numel() == 0 ||
247+
src_index.numel() == 0 || dst_index.numel() == 0) {
248+
phi::Full<T, Context>(
249+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
250+
phi::Full<T, Context>(
251+
dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
252+
return;
253+
}
254+
244255
if (index_type == phi::DataType::INT32) {
245256
GraphSendUVGradOpKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
246257
x,

paddle/phi/kernels/cpu/send_uv_kernel.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "paddle/phi/backends/cpu/cpu_context.h"
1919
#include "paddle/phi/core/kernel_registry.h"
2020
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
21+
#include "paddle/phi/kernels/full_kernel.h"
2122
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
2223

2324
namespace phi {
@@ -105,6 +106,14 @@ void SendUVKernel(const Context& dev_ctx,
105106
const std::string& message_op,
106107
DenseTensor* out) {
107108
auto index_type = src_index.dtype();
109+
110+
if (x.numel() == 0 || y.numel() == 0 || src_index.numel() == 0 ||
111+
dst_index.numel() == 0) {
112+
phi::Full<T, Context>(
113+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
114+
return;
115+
}
116+
108117
if (index_type == phi::DataType::INT32) {
109118
GraphSendUVOpKernelLaunchHelper<Context, T, int32_t>(
110119
dev_ctx, x, y, src_index, dst_index, message_op, out);

paddle/phi/kernels/gpu/send_u_recv_grad_kernel.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "paddle/common/hostdevice.h"
2121
#include "paddle/phi/backends/gpu/gpu_context.h"
2222
#include "paddle/phi/core/kernel_registry.h"
23+
#include "paddle/phi/kernels/full_kernel.h"
2324
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
2425

2526
namespace phi {
@@ -105,6 +106,14 @@ void SendURecvGradKernel(const Context& dev_ctx,
105106
const std::string& reduce_op,
106107
DenseTensor* x_grad) {
107108
auto index_type = src_index.dtype();
109+
110+
if (out_grad.numel() == 0 || x.numel() == 0 || src_index.numel() == 0 ||
111+
dst_index.numel() == 0) {
112+
phi::Full<T, Context>(
113+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
114+
return;
115+
}
116+
108117
if (index_type == phi::DataType::INT32) {
109118
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
110119
dev_ctx,

paddle/phi/kernels/gpu/send_u_recv_kernel.cu

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "paddle/common/hostdevice.h"
2525
#include "paddle/phi/backends/gpu/gpu_context.h"
2626
#include "paddle/phi/core/kernel_registry.h"
27+
#include "paddle/phi/kernels/full_kernel.h"
2728
#include "paddle/phi/kernels/funcs/math_function.h"
2829
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
2930

@@ -152,6 +153,28 @@ void SendURecvKernel(const Context& dev_ctx,
152153
DenseTensor* dst_count) {
153154
auto index_type = src_index.dtype();
154155
auto& out_size_data = out_size.GetData();
156+
157+
if (x.numel() == 0 || src_index.numel() == 0 || dst_index.numel() == 0) {
158+
if (out_size_data[0] <= 0) {
159+
out->Resize(x.dims());
160+
} else {
161+
out->Resize(common::make_ddim(out_size_data));
162+
}
163+
if (reduce_op == "MEAN") {
164+
int64_t input_size =
165+
out_size_data[0] <= 0 ? x.dims()[0] : out_size_data[0];
166+
dst_count->Resize({input_size});
167+
}
168+
phi::Full<T, Context>(
169+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
170+
phi::Full<int32_t, Context>(
171+
dev_ctx,
172+
phi::IntArray(common::vectorize(dst_count->dims())),
173+
0,
174+
dst_count);
175+
return;
176+
}
177+
155178
if (index_type == phi::DataType::INT32) {
156179
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
157180
x,

0 commit comments

Comments
 (0)