Skip to content

Commit f51e3ff

Browse files
Use gather to replace gather nd when the bool index is one-dimensional (#72625)
1 parent 754f949 commit f51e3ff

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

paddle/fluid/pybind/slice_utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,12 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
519519
if (bool_index.shape().size() == tensor_shape.size()) {
520520
return masked_select_ad_func(tensor, bool_index);
521521
}
522+
523+
if (bool_index.shape().size() == 1) {
524+
auto bool_2_idx = nonzero_ad_func(bool_index);
525+
return gather_ad_func(tensor, bool_2_idx);
526+
}
527+
522528
auto bool_2_idx = nonzero_ad_func(bool_index);
523529
return gather_nd_ad_func(tensor, bool_2_idx);
524530
}

0 commit comments

Comments
 (0)