 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/complex.h"
+#include "paddle/phi/common/int_array.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/strided_slice.h"
+#include "paddle/phi/kernels/impl/share_data_kernel_impl.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
+#include "paddle/phi/kernels/reshape_kernel.h"
+
+namespace phi {
+
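+// Recursively enumerate every offset at which a `small_dim`-shaped block is
+// laid out inside `big_dim`, stepping through the axes where the two shapes
+// differ. Used below to accumulate the sliced out_grad into value_grad when
+// the value tensor was broadcast in the forward set_value.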
+inline void GetOffsets(const DDim& big_dim,
+                       const DDim& small_dim,
+                       DDim start_offset,
+                       int cur_dim,
+                       std::vector<DDim>* offsets) {
+  if (cur_dim == big_dim.size()) {
+    offsets->push_back(start_offset);
+    return;
+  }
+  if (small_dim[cur_dim] == big_dim[cur_dim]) {
+    GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
+  } else {
+    for (int i = 0; i < big_dim[cur_dim]; i++) {
+      GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
+      start_offset[cur_dim] += 1;
+    }
+  }
+}
+
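+// Rank-specialized backward pass: x_grad is out_grad with the strided-slice
+// region zeroed out, and value_grad is the sliced portion of out_grad,
+// reversed for negative steps and reduced back to the (broadcast) value
+// shape when necessary.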
+template <typename T, typename Context, size_t RANK>
+void SetValueGradImpl(const Context& dev_ctx,
+                      const DenseTensor& out_grad,
+                      std::vector<int64_t>& starts_local,  // NOLINT
+                      std::vector<int64_t>& ends_local,    // NOLINT
+                      std::vector<int64_t>& steps_local,   // NOLINT
+                      const std::vector<int64_t>& axes,
+                      const std::vector<int64_t>& decrease_axes,
+                      const std::vector<int64_t>& none_axes UNUSED,
+                      DenseTensor* x_grad,
+                      DenseTensor* value_grad) {
+  PADDLE_ENFORCE_EQ(
+      out_grad.IsInitialized(),
+      true,
+      errors::PermissionDenied(
+          "The input of `set_value_grad`(out_grad) has not been initialized"));
+
+  auto in_dims = out_grad.dims();
+
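+  // Normalize starts/ends/steps against the input shape and compute the
+  // dimensions of the sliced region using the shared strided_slice helpers.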
+  std::vector<int> decrease_axis_int32(decrease_axes.begin(),
+                                       decrease_axes.end());
+  std::vector<int> axes_int32(axes.begin(), axes.end());
+  std::vector<int> infer_flags(axes.size(), 1);
+  std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
+  funcs::StridedSliceOutDims(starts_local,
+                             ends_local,
+                             steps_local,
+                             axes_int32,
+                             infer_flags,
+                             in_dims,
+                             decrease_axis_int32,
+                             out_dims_vector.data(),
+                             axes.size(),
+                             false);
+
+  DDim out_dims(common::make_ddim(out_dims_vector));
+
+  std::vector<int> reverse_vector(starts_local.size(), 0);
+  funcs::StridedSliceFunctor(starts_local.data(),
+                             ends_local.data(),
+                             steps_local.data(),
+                             axes_int32.data(),
+                             reverse_vector.data(),
+                             in_dims,
+                             infer_flags,
+                             decrease_axis_int32,
+                             starts_local.size());
+
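+  // Translate the normalized slice parameters into Eigen index arrays of the
+  // static RANK, and record which axes must be reversed (negative steps).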
+  auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
+  auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
+  auto steps_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
+  auto reverse_axis = Eigen::array<bool, RANK>();
+
+  for (size_t axis = 0; axis < RANK; axis++) {
+    starts_indices[axis] = 0;
+    ends_indices[axis] = out_dims[axis];
+    steps_indices[axis] = 1;
+    reverse_axis[axis] = false;
+  }
+
+  for (size_t axis = 0; axis < axes.size(); axis++) {
+    int axis_index = axes[axis];
+    starts_indices[axis_index] = starts_local[axis];
+    ends_indices[axis_index] = ends_local[axis];
+    steps_indices[axis_index] = steps_local[axis];
+    reverse_axis[axis_index] = (reverse_vector[axis] == 1);
+  }
+
+  bool need_reverse = false;
+  for (size_t axis = 0; axis < axes.size(); axis++) {
+    if (reverse_vector[axis] == 1) {
+      need_reverse = true;
+      break;
+    }
+  }
+
+  auto& place = *dev_ctx.eigen_device();
+  phi::funcs::SetConstant<Context, T> set_zero;
+
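+  // Gradient w.r.t. the input tensor: the values inside the assigned slice
+  // came from ValueTensor, so that region receives zero gradient while the
+  // rest of out_grad passes through unchanged.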
+  if (x_grad) {
+    // Set gradient of `Input`
+    Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+
+    auto x_grad_t =
+        EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(*x_grad);
+
+    DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
+    auto tmp_t =
+        EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
+
+    x_grad_t.stridedSlice(starts_indices, ends_indices, steps_indices)
+        .device(place) = tmp_t;
+  }
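+  // Gradient w.r.t. ValueTensor: slice out_grad with the same indices used in
+  // the forward pass, reverse axes with negative steps, and if the value was
+  // broadcast in the forward pass, accumulate the slice back to its shape.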
+  if (value_grad) {
+    dev_ctx.template Alloc<T>(value_grad);
+    set_zero(dev_ctx, value_grad, static_cast<T>(0));
+
+    auto in_t = EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
+        out_grad);
+
+    if (value_grad->dims() == out_dims) {
+      auto value_grad_t =
+          EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
+              *value_grad);
+      if (need_reverse) {
+        DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
+        auto tmp_t =
+            EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
+
+        tmp_t.device(place) =
+            in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+        value_grad_t.device(place) = tmp_t.reverse(reverse_axis);
+      } else {
+        value_grad_t.device(place) =
+            in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+      }
+    } else {
+      int out_dims_size = out_dims.size();
+      auto value_grad_dims = value_grad->dims();
+      auto fake_value_grad_dims = out_dims;
+
+      // Create an extended shape according to the rules of broadcast.
+      auto value_grad_dims_size = value_grad_dims.size();
+
+      int num_decrease = 0;
+
+      int decrease_axis_size = decrease_axes.size();
+      for (int i = 0; i < out_dims_size; i++) {
+        if (decrease_axes.end() !=
+            std::find(decrease_axes.begin(), decrease_axes.end(), i)) {
+          fake_value_grad_dims[i] = 1;
+          num_decrease++;
+        } else if (i < out_dims_size - (value_grad_dims_size +
+                                        decrease_axis_size - num_decrease)) {
+          fake_value_grad_dims[i] = 1;
+        } else {
+          auto index_grad =
+              i - (out_dims_size -
+                   (value_grad_dims_size + decrease_axis_size - num_decrease));
+          fake_value_grad_dims[i] = value_grad_dims[index_grad];
+
+          PADDLE_ENFORCE_EQ(
+              (out_dims[i] == value_grad_dims[index_grad]) ||
+                  (value_grad_dims[index_grad] == 1),
+              true,
+              errors::InvalidArgument("An error occurred while calculating %s: "
+                                      "[%s] cannot be accumulated into [%s].",
+                                      "ValueTensor@GRAD",
+                                      out_dims,
+                                      value_grad_dims));
+        }
+      }
+
+      VLOG(3) << "Dimensions of ValueTensor@GRAD ([" << value_grad_dims
+              << "]) are broadcast into [" << fake_value_grad_dims << "].";
+
+      auto extent = Eigen::DSizes<Eigen::DenseIndex, RANK>();
+      auto offset = out_dims;
+      for (int i = 0; i < out_dims_size; i++) {
+        offset[i] = 0;
+        extent[i] = fake_value_grad_dims[i];
+      }
+      std::vector<DDim> offsets;
+      GetOffsets(out_dims, fake_value_grad_dims, offset, 0, &offsets);
+
+      auto value_grad_t =
+          EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
+              *value_grad, fake_value_grad_dims);
+
+      DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
+      auto tmp_t =
+          EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
+
+      tmp_t.device(place) =
+          in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+
+      // accumulate gradient
+      for (auto offset : offsets) {
+        value_grad_t.device(place) =
+            value_grad_t + tmp_t.slice(EigenDim<RANK>::From(offset), extent);
+      }
+      if (need_reverse) {
+        DenseTensor tmp_value =
+            Full<T>(dev_ctx,
+                    {fake_value_grad_dims.Get(), fake_value_grad_dims.size()},
+                    static_cast<T>(0));
+        auto tmp_value_t =
+            EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
+                tmp_value);
+        tmp_value_t.device(place) = value_grad_t.reverse(reverse_axis);
+        value_grad_t.device(place) = tmp_value_t;
+      }
+    }
+  }
+}
+
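+// Entry point for the CPU kernel: normalizes the slice attributes, handles
+// the degenerate cases (empty out_grad, a slice that covers the whole
+// tensor), and otherwise dispatches to SetValueGradImpl on the input rank.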
+template <typename T, typename Context>
+void SetValueGradKernel(const Context& dev_ctx,
+                        const DenseTensor& out_grad,
+                        const IntArray& starts,
+                        const IntArray& ends,
+                        const IntArray& steps,
+                        const std::vector<int64_t>& axes,
+                        const std::vector<int64_t>& decrease_axes,
+                        const std::vector<int64_t>& none_axes,
+                        DenseTensor* x_grad,
+                        DenseTensor* value_grad) {
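+  // Degenerate case: an empty out_grad yields no gradient signal, so just
+  // allocate the outputs and return.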
+  if (out_grad.numel() == 0) {
+    if (x_grad) dev_ctx.template Alloc<T>(x_grad);
+    if (value_grad) dev_ctx.template Alloc<T>(value_grad);
+    return;
+  }
+  const int rank = out_grad.dims().size();
+  std::vector<int64_t> starts_local = starts.GetData();
+  std::vector<int64_t> ends_local = ends.GetData();
+  std::vector<int64_t> steps_local = steps.GetData();
+
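+  // Fast path: if every sliced axis spans its whole dimension with step 1,
+  // the forward set_value overwrote the entire tensor. Then x_grad is all
+  // zeros and value_grad is out_grad itself, reshaped or sum-reduced to the
+  // value tensor's shape.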
+  bool ellipsis_flag = true;
+  for (size_t i = 0; i < axes.size(); i++) {
+    auto idx = axes[i];
+    if (!(starts_local[i] == 0 && ends_local[i] == out_grad.dims()[idx] &&
+          steps_local[i] == 1)) {
+      ellipsis_flag = false;
+    }
+  }
+
+  if (ellipsis_flag) {
+    if (x_grad) {
+      FullKernel<T, Context>(dev_ctx,
+                             common::vectorize(x_grad->dims()),
+                             Scalar(0),
+                             x_grad->dtype(),
+                             x_grad);
+    }
+    if (value_grad) {
+      if (value_grad->numel() == out_grad.numel()) {
+        if (value_grad->dims() != out_grad.dims()) {
+          DenseTensor out_grad_temp;
+          ShareDataKernel<T, Context>(dev_ctx, out_grad, &out_grad_temp);
+          out_grad_temp.Resize(value_grad->dims());
+          Copy(dev_ctx, out_grad_temp, dev_ctx.GetPlace(), false, value_grad);
+        } else {
+          Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, value_grad);
+        }
+      } else {
+        auto reduce_dim = phi::funcs::GetReduceDims(out_grad, *value_grad);
+        SumKernel<T, Context>(
+            dev_ctx, out_grad, reduce_dim, out_grad.dtype(), false, value_grad);
+      }
+    }
+    return;
+  }
+
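+  // Dispatch on the runtime rank (1-6) to the statically-ranked implementation.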
+  switch (rank) {
+#define CASE_RANK(__Rk)                              \
+  case __Rk:                                         \
+    SetValueGradImpl<T, Context, __Rk>(dev_ctx,      \
+                                       out_grad,     \
+                                       starts_local, \
+                                       ends_local,   \
+                                       steps_local,  \
+                                       axes,         \
+                                       decrease_axes,\
+                                       none_axes,    \
+                                       x_grad,       \
+                                       value_grad);  \
+    break;
+    CASE_RANK(1);
+    CASE_RANK(2);
+    CASE_RANK(3);
+    CASE_RANK(4);
+    CASE_RANK(5);
+    CASE_RANK(6);
+#undef CASE_RANK
+    default:
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "The rank of set_value_grad's input should be less than 7, but "
+          "received %d.",
+          rank));
+  }
+  return;
+}
+
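+// set_value with a scalar value has no ValueTensor input, so only x_grad is
+// produced; reuse SetValueGradKernel with value_grad == nullptr.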
+template <typename T, typename Context>
+void SetValueWithScalarGradKernel(const Context& dev_ctx,
+                                  const DenseTensor& out_grad,
+                                  const IntArray& starts,
+                                  const IntArray& ends,
+                                  const IntArray& steps,
+                                  const std::vector<int64_t>& axes,
+                                  const std::vector<int64_t>& decrease_axes,
+                                  const std::vector<int64_t>& none_axes,
+                                  DenseTensor* x_grad) {
+  SetValueGradKernel<T, Context>(dev_ctx,
+                                 out_grad,
+                                 starts,
+                                 ends,
+                                 steps,
+                                 axes,
+                                 decrease_axes,
+                                 none_axes,
+                                 x_grad,
+                                 nullptr);
+}
+
+}  // namespace phi
 
 PD_REGISTER_KERNEL(set_value_grad,
                    CPU,