diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu
index 99dbc1d17153e6..72f27299b23e49 100644
--- a/paddle/phi/kernels/gpu/cum_kernel.cu
+++ b/paddle/phi/kernels/gpu/cum_kernel.cu
@@ -56,24 +56,6 @@ __global__ void MatrixRowReverse(const T* matrix_data,
   }
 }
 
-template <typename T, typename Op>
-struct BlockPrefixCallbackOp {
-  // Running prefix
-  T running_total_;
-  Op op_;
-
-  __device__ BlockPrefixCallbackOp(T running_total, Op op)
-      : running_total_(running_total), op_(op) {}
-
-  // Callback operator to be entered by the first warp of threads in the block.
-  // tid 0 is responsible for returning a value for seeding the block-wide scan.
-  __device__ T operator()(T block_aggregate) {
-    T old_prefix = running_total_;
-    running_total_ = op_(old_prefix, block_aggregate);
-    return old_prefix;
-  }
-};
-
 // No bank-conflict transpose
 template <typename T>
 __global__ void MatrixTranspose(T* odata,
@@ -146,6 +128,73 @@ struct Identity {
   static constexpr T value = {0, 0};
 };
 
+template <typename T, typename Op>
+struct BlockPrefixCallbackOp {
+  // Running prefix
+  T running_total_;
+  T compensation_;
+  Op op_;
+
+  __device__ BlockPrefixCallbackOp(T identity, Op op)
+      : running_total_(identity), compensation_(identity), op_(op) {}
+
+  // Callback operator to be entered by the first warp of threads in the block.
+  // tid 0 is responsible for returning a value for seeding the block-wide scan.
+  __device__ T operator()(T block_aggregate) {
+    T old_prefix = running_total_;
+
+    // Kahan Summation
+    T y = op_(block_aggregate, static_cast<T>(-compensation_));
+    T t = op_(running_total_, y);
+    T y_high = op_(t, static_cast<T>(-running_total_));
+    compensation_ = op_(y_high, static_cast<T>(-y));
+    running_total_ = t;
+
+    return old_prefix;
+  }
+};
+
+template <typename T>
+struct BlockPrefixCallbackOp<T, LogAddExp> {
+  T max_so_far_;
+  T scaled_sum_;
+  T compensation_;
+  LogAddExp op_;
+
+  __device__ BlockPrefixCallbackOp(T identity, LogAddExp op)
+      : max_so_far_(identity), scaled_sum_(0.0), compensation_(0.0), op_(op) {}
+
+  __device__ T operator()(T block_aggregate) {
+    if (scaled_sum_ == 0.0) {
+      max_so_far_ = block_aggregate;
+      scaled_sum_ = 1.0;
+      compensation_ = 0.0;
+      return std::numeric_limits<T>::lowest();
+    }
+
+    // Online Scaling
+    T old_prefix = max_so_far_ + std::log(scaled_sum_);
+    T m_old = max_so_far_;
+    T m_new = std::max(m_old, block_aggregate);
+
+    if (m_new > m_old) {
+      T scale = std::exp(m_old - m_new);
+      scaled_sum_ *= scale;
+      compensation_ *= scale;
+    }
+
+    // Kahan Summation
+    T term = std::exp(block_aggregate - m_new);
+    T y = term - compensation_;
+    T t = scaled_sum_ + y;
+    compensation_ = (t - scaled_sum_) - y;
+    scaled_sum_ = t;
+    max_so_far_ = m_new;
+
+    return old_prefix;
+  }
+};
+
 template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
 __global__ void BlockScanKernel(T* d_out,
                                 const T* d_in,
@@ -154,17 +203,17 @@ __global__ void BlockScanKernel(T* d_out,
                                 bool exclusive,
                                 Op op) {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  using CallbackOp = BlockPrefixCallbackOp<MT, Op>;
   // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
-  typedef cub::
-      BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
-          BlockLoadT;
-  typedef cub::BlockStore<MT,
-                          BLOCK_THREADS,
-                          ITEMS_PER_THREAD,
-                          cub::BLOCK_STORE_TRANSPOSE>
-      BlockStoreT;
-  typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
+  using BlockLoadT = cub::
+      BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>;
+  using BlockStoreT = cub::BlockStore<MT,
+                                      BLOCK_THREADS,
+                                      ITEMS_PER_THREAD,
+                                      cub::BLOCK_STORE_TRANSPOSE>;
+  using BlockScanT = cub::BlockScan<MT, BLOCK_THREADS>;
+
   // Allocate type-safe, repurposable shared memory for collectives
   __shared__ union {
     typename BlockLoadT::TempStorage load;
@@ -176,24 +225,21 @@ __global__ void BlockScanKernel(T* d_out,
   int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
 
   for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
-    BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
+    CallbackOp prefix_op(Identity<MT, Op>::value, op);
 
     for (int64_t block_offset = 0; block_offset < scan_size;
          block_offset += item_per_block) {
-      int64_t valid_item = (scan_size - block_offset > item_per_block)
-                               ? item_per_block
-                               : (scan_size - block_offset);
-      if (scan_size < item_per_block) {
-        valid_item = scan_size;
-      }
+      int64_t valid_item = std::min(scan_size - block_offset, item_per_block);
 
       int64_t offset = bx * scan_size + block_offset;
 
       MT thread_keys[ITEMS_PER_THREAD];
       BlockLoadT(temp_storage.load)
-          .Load(d_in + offset, thread_keys, valid_item, 0);
+          .Load(
+              d_in + offset, thread_keys, valid_item, Identity<MT, Op>::value);
 
       __syncthreads();
+
       if (exclusive) {
         BlockScanT(temp_storage.scan)
             .ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
@@ -209,63 +255,6 @@ __global__ void BlockScanKernel(T* d_out,
   }
 }
 
-template <typename Context, typename T>
-typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value &&
-                        !std::is_same<T, phi::dtype::bfloat16>::value>::type
-ThrustCumsumKernel(const Context& dev_ctx,
-                   const T* in_data,
-                   T* out_data,
-                   int64_t size,
-                   bool reverse,
-                   bool exclusive) {
-#ifdef __HIPCC__
-  const auto& policy = thrust::hip::par.on(dev_ctx.stream());
-#else
-  phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
-                                                             dev_ctx.stream());
-  const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
-#endif
-  if (reverse) {
-    thrust::reverse_iterator<thrust::device_ptr<T>> reversed_in(
-        thrust::device_pointer_cast(in_data) + size);
-    thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
-        thrust::device_pointer_cast(out_data) + size);
-    if (exclusive) {
-      thrust::exclusive_scan(
-          policy, reversed_in, reversed_in + size, reversed_out);
-    } else {
-      thrust::inclusive_scan(
-          policy, reversed_in, reversed_in + size, reversed_out);
-    }
-  } else {
-    if (exclusive) {
-      thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
-    } else {
-      thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
-    }
-  }
-
-  return;
-}
-
-template <typename Context, typename T>
-typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
-ThrustCumsumKernel(const Context& dev_ctx,
-                   const phi::dtype::float16* in_data,
-                   phi::dtype::float16* out_data,
-                   int64_t size,
-                   bool reverse,
-                   bool exclusive) {}
-
-template <typename Context, typename T>
-typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
-ThrustCumsumKernel(const Context& dev_ctx,
-                   const phi::dtype::bfloat16* in_data,
-                   phi::dtype::bfloat16* out_data,
-                   int64_t size,
-                   bool reverse,
-                   bool exclusive) {}
-
 template <typename T, typename Context, typename Op>
 void ScanKernel(const Context& dev_ctx,
                 const DenseTensor& x,
@@ -290,7 +279,6 @@ void ScanKernel(const Context& dev_ctx,
   }
 
   auto out_dims = out->dims();
-  auto size = x.numel();
 
   PADDLE_ENFORCE_EQ(
       axis < out_dims.size() && axis >= (0 - out_dims.size()),
@@ -307,22 +295,11 @@ void ScanKernel(const Context& dev_ctx,
 
   const T* in_data = x.data<T>();
 
-  // Use thrust for parallel acceleration when the input size is equal to the
-  // length of the 'axis' dimension.
-  if (!std::is_same<T, phi::dtype::float16>::value &&
-      !std::is_same<T, phi::dtype::bfloat16>::value &&
-      std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
-    ThrustCumsumKernel<Context, T>(
-        dev_ctx, in_data, out_data, size, reverse, exclusive);
-    return;
-  }
-
   size_t height = 1;
   size_t width = 1;
   for (size_t i = 0; i <= axis; i++) {
     height *= out_dims[i];
   }
-
   for (size_t i = axis + 1; i < out_dims.size(); i++) {
     width *= out_dims[i];
   }
diff --git a/test/legacy_test/test_logcumsumexp_op.py b/test/legacy_test/test_logcumsumexp_op.py
index 9e55fec9efab8a..615b5298e54d1e 100644
--- a/test/legacy_test/test_logcumsumexp_op.py
+++ b/test/legacy_test/test_logcumsumexp_op.py
@@ -77,8 +77,13 @@ def np_logcumsumexp_grad(
     exclusive: bool = False,
 ):
     out = np_logcumsumexp(x, axis, flatten, reverse, exclusive)
-    log_grad_positive = np.where(dout > 0, np.log(dout), np.finfo(x.dtype).min)
-    log_grad_negative = np.where(dout < 0, np.log(-dout), np.finfo(x.dtype).min)
+    dout = np.asarray(dout)
+    pos_mask = dout > 0
+    neg_mask = dout < 0
+    log_grad_positive = np.full_like(dout, np.finfo(x.dtype).min)
+    log_grad_negative = np.full_like(dout, np.finfo(x.dtype).min)
+    log_grad_positive[pos_mask] = np.log(dout[pos_mask])
+    log_grad_negative[neg_mask] = np.log(-dout[neg_mask])
     output_pos = np.exp(
         np_logcumsumexp(
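For reference, the NumPy sketch below mirrors the numerical scheme that the new BlockPrefixCallbackOp<T, LogAddExp> specialization maintains per block: a running maximum together with an online-rescaled, Kahan-compensated sum of exp(x_i - max). It is illustrative only and not part of the patch; the function and variable names (logcumsumexp_online_kahan, scaled_sum, compensation) are hypothetical.

# Illustrative sketch, not part of the patch above. Names are hypothetical.
import numpy as np


def logcumsumexp_online_kahan(x):
    """Sequential reference for an online-scaled, Kahan-compensated
    log-add-exp prefix (the scheme used by the LogAddExp callback)."""
    x = np.asarray(x, dtype=np.float64)
    out = np.empty_like(x)
    max_so_far = -np.inf  # running maximum m
    scaled_sum = 0.0      # running sum of exp(x_i - m)
    compensation = 0.0    # Kahan compensation term for scaled_sum
    for i, v in enumerate(x):
        m_new = max(max_so_far, v)
        if m_new > max_so_far and scaled_sum != 0.0:
            # Online scaling: rescale the accumulated sum (and its
            # compensation) to the new maximum.
            scale = np.exp(max_so_far - m_new)
            scaled_sum *= scale
            compensation *= scale
        # Kahan summation of the new term exp(v - m_new).
        term = np.exp(v - m_new)
        y = term - compensation
        t = scaled_sum + y
        compensation = (t - scaled_sum) - y
        scaled_sum = t
        max_so_far = m_new
        out[i] = max_so_far + np.log(scaled_sum)
    return out


if __name__ == "__main__":
    x = np.random.default_rng(0).normal(size=1000)
    assert np.allclose(logcumsumexp_online_kahan(x), np.logaddexp.accumulate(x))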