diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index e540d2f4f5d5fb..3f59acb12ba26f 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -519,6 +519,12 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor, if (bool_index.shape().size() == tensor_shape.size()) { return masked_select_ad_func(tensor, bool_index); } + + if (bool_index.shape().size() == 1) { + auto bool_2_idx = nonzero_ad_func(bool_index); + return gather_ad_func(tensor, bool_2_idx); + } + auto bool_2_idx = nonzero_ad_func(bool_index); return gather_nd_ad_func(tensor, bool_2_idx); }