From 44b01a89e7aafc2cc76aed5f625fe7880d6549e6 Mon Sep 17 00:00:00 2001 From: zhanghonggeng Date: Thu, 8 May 2025 10:35:57 +0000 Subject: [PATCH] Use gather to replace gather nd when the bool index is one-dimensional --- paddle/fluid/pybind/slice_utils.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index e540d2f4f5d5fb..3f59acb12ba26f 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -519,6 +519,12 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor, if (bool_index.shape().size() == tensor_shape.size()) { return masked_select_ad_func(tensor, bool_index); } + + if (bool_index.shape().size() == 1) { + auto bool_2_idx = nonzero_ad_func(bool_index); + return gather_ad_func(tensor, bool_2_idx); + } + auto bool_2_idx = nonzero_ad_func(bool_index); return gather_nd_ad_func(tensor, bool_2_idx); }