
Commit 90dc213

[phi] Migrate set_value CPU kernel to cpu directory and rename GPU kernel (#74037)
* Migrate set_value
* Migrate set_value_grad
* Review: use Macro
1 parent dcda866 commit 90dc213

File tree

7 files changed: +639 / -742 lines

paddle/phi/kernels/cpu/set_value_grad_kernel.cc

Lines changed: 342 additions & 1 deletion
@@ -16,8 +16,349 @@
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/complex.h"
+#include "paddle/phi/common/int_array.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/strided_slice.h"
+#include "paddle/phi/kernels/impl/share_data_kernel_impl.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
+#include "paddle/phi/kernels/reshape_kernel.h"
+
+namespace phi {
+
+inline void GetOffsets(const DDim& big_dim,
+                       const DDim& small_dim,
+                       DDim start_offset,
+                       int cur_dim,
+                       std::vector<DDim>* offsets) {
+  if (cur_dim == big_dim.size()) {
+    offsets->push_back(start_offset);
+    return;
+  }
+  if (small_dim[cur_dim] == big_dim[cur_dim]) {
+    GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
+  } else {
+    for (int i = 0; i < big_dim[cur_dim]; i++) {
+      GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
+      start_offset[cur_dim] += 1;
+    }
+  }
+}
+
+template <typename T, typename Context, size_t RANK>
+void SetValueGradImpl(const Context& dev_ctx,
+                      const DenseTensor& out_grad,
+                      std::vector<int64_t>& starts_local,  // NOLINT
+                      std::vector<int64_t>& ends_local,    // NOLINT
+                      std::vector<int64_t>& steps_local,   // NOLINT
+                      const std::vector<int64_t>& axes,
+                      const std::vector<int64_t>& decrease_axes,
+                      const std::vector<int64_t>& none_axes UNUSED,
+                      DenseTensor* x_grad,
+                      DenseTensor* value_grad) {
+  PADDLE_ENFORCE_EQ(
+      out_grad.IsInitialized(),
+      true,
+      errors::PermissionDenied(
+          "The input of `set_value_grad`(out_grad) has not been initialized"));
+
+  auto in_dims = out_grad.dims();
+
+  std::vector<int> decrease_axis_int32(decrease_axes.begin(),
+                                       decrease_axes.end());
+  std::vector<int> axes_int32(axes.begin(), axes.end());
+  std::vector<int> infer_flags(axes.size(), 1);
+  std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
+  funcs::StridedSliceOutDims(starts_local,
+                             ends_local,
+                             steps_local,
+                             axes_int32,
+                             infer_flags,
+                             in_dims,
+                             decrease_axis_int32,
+                             out_dims_vector.data(),
+                             axes.size(),
+                             false);
+
+  DDim out_dims(common::make_ddim(out_dims_vector));
+
+  std::vector<int> reverse_vector(starts_local.size(), 0);
+  funcs::StridedSliceFunctor(starts_local.data(),
+                             ends_local.data(),
+                             steps_local.data(),
+                             axes_int32.data(),
+                             reverse_vector.data(),
+                             in_dims,
+                             infer_flags,
+                             decrease_axis_int32,
+                             starts_local.size());
+
+  auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
+  auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
+  auto steps_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
+  auto reverse_axis = Eigen::array<bool, RANK>();
+
+  for (size_t axis = 0; axis < RANK; axis++) {
+    starts_indices[axis] = 0;
+    ends_indices[axis] = out_dims[axis];
+    steps_indices[axis] = 1;
+    reverse_axis[axis] = false;
+  }
+
+  for (size_t axis = 0; axis < axes.size(); axis++) {
+    int axis_index = axes[axis];
+    starts_indices[axis_index] = starts_local[axis];
+    ends_indices[axis_index] = ends_local[axis];
+    steps_indices[axis_index] = steps_local[axis];
+    reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
+  }
+
+  bool need_reverse = false;
+  for (size_t axis = 0; axis < axes.size(); axis++) {
+    if (reverse_vector[axis] == 1) {
+      need_reverse = true;
+      break;
+    }
+  }
+
+  auto& place = *dev_ctx.eigen_device();
+  phi::funcs::SetConstant<Context, T> set_zero;
+
+  if (x_grad) {
+    // Set gradient of `Input`
+    Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+
+    auto x_grad_t =
+        EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(*x_grad);
+
+    DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
+    auto tmp_t =
+        EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
+
+    x_grad_t.stridedSlice(starts_indices, ends_indices, steps_indices)
+        .device(place) = tmp_t;
+  }
+  if (value_grad) {
+    dev_ctx.template Alloc<T>(value_grad);
+    set_zero(dev_ctx, value_grad, static_cast<T>(0));
+
+    auto in_t = EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
+        out_grad);
+
+    if (value_grad->dims() == out_dims) {
+      auto value_grad_t =
+          EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
+              *value_grad);
+      if (need_reverse) {
+        DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
+        auto tmp_t =
+            EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
+
+        tmp_t.device(place) =
+            in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+        value_grad_t.device(place) = tmp_t.reverse(reverse_axis);
+      } else {
+        value_grad_t.device(place) =
+            in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+      }
+    } else {
+      int out_dims_size = out_dims.size();
+      auto value_grad_dims = value_grad->dims();
+      auto fake_value_grad_dims = out_dims;
+
+      // Create an extended shape according to the rules of broadcast.
+      auto value_grad_dims_size = value_grad_dims.size();
+
+      int num_decrease = 0;
+
+      int decrease_axis_size = decrease_axes.size();
+      for (int i = 0; i < out_dims_size; i++) {
+        if (decrease_axes.end() !=
+            std::find(decrease_axes.begin(), decrease_axes.end(), i)) {
+          fake_value_grad_dims[i] = 1;
+          num_decrease++;
+        } else if (i < out_dims_size - (value_grad_dims_size +
+                                        decrease_axis_size - num_decrease)) {
+          fake_value_grad_dims[i] = 1;
+        } else {
+          auto index_grad =
+              i - (out_dims_size -
+                   (value_grad_dims_size + decrease_axis_size - num_decrease));
+          fake_value_grad_dims[i] = value_grad_dims[index_grad];
+
+          PADDLE_ENFORCE_EQ(
+              (out_dims[i] == value_grad_dims[index_grad]) ||
+                  (value_grad_dims[index_grad] == 1),
+              true,
+              errors::InvalidArgument("An error occurred while calculating %s: "
+                                      "[%s] can not be accumulated into [%s].",
+                                      "ValueTensor@GRAD",
+                                      out_dims,
+                                      value_grad_dims));
+        }
+      }
+
+      VLOG(3) << "Dimensions of "
+              << "ValueTensor@GRAD"
+              << "([" << value_grad_dims << "])is broadcasted into ["
+              << fake_value_grad_dims << "].";
+
+      auto extent = Eigen::DSizes<Eigen::DenseIndex, RANK>();
+      auto offset = out_dims;
+      for (int i = 0; i < out_dims_size; i++) {
+        offset[i] = 0;
+        extent[i] = fake_value_grad_dims[i];
+      }
+      std::vector<DDim> offsets;
+      GetOffsets(out_dims, fake_value_grad_dims, offset, 0, &offsets);
+
+      auto value_grad_t =
+          EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
+              *value_grad, fake_value_grad_dims);
+
+      DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
+      auto tmp_t =
+          EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
+
+      tmp_t.device(place) =
+          in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+
+      // accumulate gradient
+      for (auto offset : offsets) {
+        value_grad_t.device(place) =
+            value_grad_t + tmp_t.slice(EigenDim<RANK>::From(offset), extent);
+      }
+      if (need_reverse) {
+        DenseTensor tmp_value =
+            Full<T>(dev_ctx,
+                    {fake_value_grad_dims.Get(), fake_value_grad_dims.size()},
+                    static_cast<T>(0));
+        auto tmp_value_t =
+            EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
+                tmp_value);
+        tmp_value_t.device(place) = value_grad_t.reverse(reverse_axis);
+        value_grad_t.device(place) = tmp_value_t;
+      }
+    }
+  }
+}
+
+template <typename T, typename Context>
+void SetValueGradKernel(const Context& dev_ctx,
+                        const DenseTensor& out_grad,
+                        const IntArray& starts,
+                        const IntArray& ends,
+                        const IntArray& steps,
+                        const std::vector<int64_t>& axes,
+                        const std::vector<int64_t>& decrease_axes,
+                        const std::vector<int64_t>& none_axes,
+                        DenseTensor* x_grad,
+                        DenseTensor* value_grad) {
+  if (out_grad.numel() == 0) {
+    if (x_grad) dev_ctx.template Alloc<T>(x_grad);
+    if (value_grad) dev_ctx.template Alloc<T>(value_grad);
+    return;
+  }
+  const int rank = out_grad.dims().size();
+  std::vector<int64_t> starts_local = starts.GetData();
+  std::vector<int64_t> ends_local = ends.GetData();
+  std::vector<int64_t> steps_local = steps.GetData();
+
+  bool ellipsis_flag = true;
+  for (size_t i = 0; i < axes.size(); i++) {
+    auto idx = axes[i];
+    if (!(starts_local[i] == 0 && ends_local[i] == out_grad.dims()[idx] &&
+          steps_local[i] == 1)) {
+      ellipsis_flag = false;
+    }
+  }
+
+  if (ellipsis_flag) {
+    if (x_grad) {
+      FullKernel<T, Context>(dev_ctx,
+                             common::vectorize(x_grad->dims()),
+                             Scalar(0),
+                             x_grad->dtype(),
+                             x_grad);
+    }
+    if (value_grad) {
+      if (value_grad->numel() == out_grad.numel()) {
+        if (value_grad->dims() != out_grad.dims()) {
+          DenseTensor out_grad_temp;
+          ShareDataKernel<T, Context>(dev_ctx, out_grad, &out_grad_temp);
+          out_grad_temp.Resize(value_grad->dims());
+          Copy(dev_ctx, out_grad_temp, dev_ctx.GetPlace(), false, value_grad);
+        } else {
+          Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, value_grad);
+        }
+      } else {
+        auto reduce_dim = phi::funcs::GetReduceDims(out_grad, *value_grad);
+        SumKernel<T, Context>(
+            dev_ctx, out_grad, reduce_dim, out_grad.dtype(), false, value_grad);
+      }
+    }
+    return;
+  }
+
+  switch (rank) {
+#define CASE_RANK(__Rk)                                \
+  case __Rk:                                           \
+    SetValueGradImpl<T, Context, __Rk>(dev_ctx,        \
+                                       out_grad,       \
+                                       starts_local,   \
+                                       ends_local,     \
+                                       steps_local,    \
+                                       axes,           \
+                                       decrease_axes,  \
+                                       none_axes,      \
+                                       x_grad,         \
+                                       value_grad);    \
+    break;
+    CASE_RANK(1);
+    CASE_RANK(2);
+    CASE_RANK(3);
+    CASE_RANK(4);
+    CASE_RANK(5);
+    CASE_RANK(6);
+#undef CASE_RANK
+    default:
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "The rank of set_value_grad's input should be less than 7, but "
+          "received %d.",
+          rank));
+  }
+  return;
+}
+
+template <typename T, typename Context>
+void SetValueWithScalarGradKernel(const Context& dev_ctx,
+                                  const DenseTensor& out_grad,
+                                  const IntArray& starts,
+                                  const IntArray& ends,
+                                  const IntArray& steps,
+                                  const std::vector<int64_t>& axes,
+                                  const std::vector<int64_t>& decrease_axes,
+                                  const std::vector<int64_t>& none_axes,
+                                  DenseTensor* x_grad) {
+  SetValueGradKernel<T, Context>(dev_ctx,
+                                 out_grad,
+                                 starts,
+                                 ends,
+                                 steps,
+                                 axes,
+                                 decrease_axes,
+                                 none_axes,
+                                 x_grad,
+                                 nullptr);
+}
+
+} // namespace phi
 
 PD_REGISTER_KERNEL(set_value_grad,
                    CPU,
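
The helper worth pausing on in this diff is GetOffsets: it enumerates every offset at which the (possibly broadcast) value shape sits inside the sliced output shape, so the gradient can be summed back into ValueTensor@GRAD. Below is a minimal standalone sketch of the same recursion, using plain std::vector<int64_t> in place of phi's DDim; the name GetOffsetsSketch and the sample shapes are illustrative, not part of the commit.

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone mirror of GetOffsets: collects every offset at which the
// small (broadcast) shape is positioned inside the big shape. Dimensions
// where the shapes match contribute a single offset; broadcast dimensions
// fan out once per index.
void GetOffsetsSketch(const std::vector<int64_t>& big,
                      const std::vector<int64_t>& small,
                      std::vector<int64_t> start,
                      size_t cur,
                      std::vector<std::vector<int64_t>>* out) {
  if (cur == big.size()) {
    out->push_back(start);
    return;
  }
  if (small[cur] == big[cur]) {
    GetOffsetsSketch(big, small, start, cur + 1, out);
  } else {
    for (int64_t i = 0; i < big[cur]; ++i) {
      GetOffsetsSketch(big, small, start, cur + 1, out);
      start[cur] += 1;
    }
  }
}

int main() {
  // big = [2, 3], small = [1, 3]: the small block tiles along axis 0,
  // so the enumerated offsets are (0, 0) and (1, 0).
  std::vector<std::vector<int64_t>> offsets;
  GetOffsetsSketch({2, 3}, {1, 3}, {0, 0}, 0, &offsets);
  for (const auto& off : offsets) {
    std::cout << "(" << off[0] << ", " << off[1] << ")\n";
  }
  return 0;
}

SetValueGradImpl then accumulates tmp_t.slice(offset, extent) over exactly these offsets, which is the reduction that broadcasting demands in the backward pass.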

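The "Review: use Macro" bullet in the commit message refers to the CASE_RANK block above: Eigen::DSizes needs the tensor rank as a compile-time constant, so the runtime rank is switched into one template instantiation per supported rank (1 through 6). Here is a toy version of that dispatch pattern, with a hypothetical PrintRank<N> standing in for SetValueGradImpl.

#include <cstdio>
#include <stdexcept>

// Stand-in for a kernel templated on a compile-time rank (hypothetical).
template <size_t RANK>
void PrintRank() {
  std::printf("instantiated for rank %zu\n", RANK);
}

// Runtime-to-compile-time dispatch, mirroring the CASE_RANK macro in the
// diff: each case expands to a distinct template instantiation.
void Dispatch(int rank) {
  switch (rank) {
#define CASE_RANK(__Rk) \
  case __Rk:            \
    PrintRank<__Rk>();  \
    break;
    CASE_RANK(1);
    CASE_RANK(2);
    CASE_RANK(3);
    CASE_RANK(4);
    CASE_RANK(5);
    CASE_RANK(6);
#undef CASE_RANK
    default:
      throw std::invalid_argument("rank must be less than 7");
  }
}

int main() {
  Dispatch(3);  // prints: instantiated for rank 3
  return 0;
}

The macro keeps the six nearly identical cases from being written out by hand, which is presumably what the review comment asked for.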
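For orientation, the semantics being differentiated: forward set_value writes value into a strided slice of x, so in the backward pass out_grad flows unchanged to x_grad except inside the written slice (which is zeroed), while the slice of out_grad flows to value_grad. Below is a 1-D, positive-step sketch that ignores broadcasting, decrease axes, and step reversal; SetValueGrad1D and the sample numbers are inventions of this illustration, not the phi API.

#include <cstdint>
#include <iostream>
#include <vector>

// 1-D illustration of set_value_grad: y = x with y[start:end:step] = v,
// so dx is dy with the written slice zeroed, and dv gathers that slice.
void SetValueGrad1D(const std::vector<float>& dy,
                    int64_t start, int64_t end, int64_t step,
                    std::vector<float>* dx, std::vector<float>* dv) {
  *dx = dy;                // gradient passes through untouched elements
  dv->clear();
  for (int64_t i = start; i < end; i += step) {
    dv->push_back(dy[i]);  // written elements: gradient goes to the value
    (*dx)[i] = 0.0f;       // ...and is cut off from x
  }
}

int main() {
  std::vector<float> dy = {1, 2, 3, 4, 5}, dx, dv;
  SetValueGrad1D(dy, 1, 4, 2, &dx, &dv);
  // dx = {1, 0, 3, 0, 5}, dv = {2, 4}
  for (float v : dx) std::cout << v << " ";
  std::cout << "| ";
  for (float v : dv) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}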