[Accuracy diff No.16] Fix accuracy diff for paddle.cumsum/paddle.logcumsumexp API #74081

Merged: 2 commits, Jul 17, 2025
185 changes: 81 additions & 104 deletions paddle/phi/kernels/gpu/cum_kernel.cu
@@ -56,24 +56,6 @@ __global__ void MatrixRowReverse(const T* matrix_data,
}
}

template <typename T, typename Op>
struct BlockPrefixCallbackOp {
// Running prefix
T running_total_;
Op op_;

__device__ BlockPrefixCallbackOp(T running_total, Op op)
: running_total_(running_total), op_(op) {}

// Callback operator to be entered by the first warp of threads in the block.
// tid 0 is responsible for returning a value for seeding the block-wide scan.
__device__ T operator()(T block_aggregate) {
T old_prefix = running_total_;
running_total_ = op_(old_prefix, block_aggregate);
return old_prefix;
}
};

// No bank-conflict transpose
template <typename T, int TILE_DIM, int BLOCK_ROWS>
__global__ void MatrixTranspose(T* odata,
@@ -146,6 +128,73 @@ struct Identity<T, ComplexSum> {
static constexpr T value = {0, 0};
};

template <typename T, typename Op>
struct BlockPrefixCallbackOp {
// Running prefix
T running_total_;
T compensation_;
Op op_;

__device__ BlockPrefixCallbackOp(T identity, Op op)
: running_total_(identity), compensation_(identity), op_(op) {}

// Callback operator to be entered by the first warp of threads in the block.
// tid 0 is responsible for returning a value for seeding the block-wide scan.
__device__ T operator()(T block_aggregate) {
T old_prefix = running_total_;

// Kahan Summation
T y = op_(block_aggregate, static_cast<T>(-compensation_));
T t = op_(running_total_, y);
T y_high = op_(t, static_cast<T>(-running_total_));
compensation_ = op_(y_high, static_cast<T>(-y));
running_total_ = t;

return old_prefix;
}
};

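// Specialization for LogAddExp: the running prefix is kept in the
// numerically stable form max_so_far_ + log(scaled_sum_); the scaled sum
// is rescaled online whenever a larger running maximum appears and is
// accumulated with Kahan compensation.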
template <typename T>
struct BlockPrefixCallbackOp<T, LogAddExp> {
T max_so_far_;
T scaled_sum_;
T compensation_;
LogAddExp op_;

__device__ BlockPrefixCallbackOp(T identity, LogAddExp op)
: max_so_far_(identity), scaled_sum_(0.0), compensation_(0.0), op_(op) {}

__device__ T operator()(T block_aggregate) {
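// First partial: seed the state from the first block aggregate and
// return lowest() (i.e. log(0)) as the empty-prefix identity.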
if (scaled_sum_ == 0.0) {
max_so_far_ = block_aggregate;
scaled_sum_ = 1.0;
compensation_ = 0.0;
return std::numeric_limits<T>::lowest();
}

// Online Scaling
T old_prefix = max_so_far_ + std::log(scaled_sum_);
T m_old = max_so_far_;
T m_new = std::max(m_old, block_aggregate);

if (m_new > m_old) {
T scale = std::exp(m_old - m_new);
scaled_sum_ *= scale;
compensation_ *= scale;
}

// Kahan Summation
T term = std::exp(block_aggregate - m_new);
T y = term - compensation_;
T t = scaled_sum_ + y;
compensation_ = (t - scaled_sum_) - y;
scaled_sum_ = t;
max_so_far_ = m_new;

return old_prefix;
}
};

template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
__global__ void BlockScanKernel(T* d_out,
const T* d_in,
@@ -154,17 +203,17 @@ __global__ void BlockScanKernel(T* d_out,
bool exclusive,
Op op) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
using CallbackOp = BlockPrefixCallbackOp<MT, Op>;

// Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
typedef cub::
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadT;
typedef cub::BlockStore<MT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>
BlockStoreT;
typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
using BlockLoadT = cub::
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>;
using BlockStoreT = cub::BlockStore<MT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>;
using BlockScanT = cub::BlockScan<MT, BLOCK_THREADS>;

// Allocate type-safe, repurposable shared memory for collectives
__shared__ union {
typename BlockLoadT::TempStorage load;
@@ -176,24 +225,21 @@ __global__ void BlockScanKernel(T* d_out,
int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;

for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
CallbackOp prefix_op(Identity<MT, Op>::value, op);

for (int64_t block_offset = 0; block_offset < scan_size;
block_offset += item_per_block) {
int64_t valid_item = (scan_size - block_offset > item_per_block)
? item_per_block
: (scan_size - block_offset);
if (scan_size < item_per_block) {
valid_item = scan_size;
}
int64_t valid_item = std::min(scan_size - block_offset, item_per_block);

int64_t offset = bx * scan_size + block_offset;

MT thread_keys[ITEMS_PER_THREAD];
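// Out-of-range tail items are padded with the op's identity so they
// cannot perturb the block-wide scan.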
BlockLoadT(temp_storage.load)
.Load(d_in + offset, thread_keys, valid_item, 0);
.Load(
d_in + offset, thread_keys, valid_item, Identity<MT, Op>::value);

__syncthreads();

if (exclusive) {
BlockScanT(temp_storage.scan)
.ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
Expand All @@ -209,63 +255,6 @@ __global__ void BlockScanKernel(T* d_out,
}
}

template <typename Context, typename T>
typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value &&
!std::is_same<T, phi::dtype::bfloat16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const T* in_data,
T* out_data,
int64_t size,
bool reverse,
bool exclusive) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
dev_ctx.stream());
const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
#endif
if (reverse) {
thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
thrust::device_pointer_cast(in_data) + size);
thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
thrust::device_pointer_cast(out_data) + size);
if (exclusive) {
thrust::exclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
} else {
thrust::inclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
}
} else {
if (exclusive) {
thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
} else {
thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
}
}

return;
}

template <typename Context, typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const phi::dtype::float16* in_data,
phi::dtype::float16* out_data,
int64_t size,
bool reverse,
bool exclusive) {}

template <typename Context, typename T>
typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const phi::dtype::bfloat16* in_data,
phi::dtype::bfloat16* out_data,
int64_t size,
bool reverse,
bool exclusive) {}

template <typename T, typename Context, typename Op>
void ScanKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -290,7 +279,6 @@ void ScanKernel(const Context& dev_ctx,
}

auto out_dims = out->dims();
auto size = x.numel();

PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
@@ -307,22 +295,11 @@

const T* in_data = x.data<T>();

// Use thrust for parallel acceleration when the input size is equal to the
// length of the 'axis' dimension.
if (!std::is_same<T, phi::dtype::float16>::value &&
!std::is_same<T, phi::dtype::bfloat16>::value &&
std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
ThrustCumsumKernel<Context, T>(
dev_ctx, in_data, out_data, size, reverse, exclusive);
return;
}

size_t height = 1;
size_t width = 1;
for (size_t i = 0; i <= axis; i++) {
height *= out_dims[i];
}

for (size_t i = axis + 1; i < out_dims.size(); i++) {
width *= out_dims[i];
}
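The heart of the fix is the pair of prefix callbacks added above: a Kahan-compensated running total for cumsum, and an online-rescaled log-sum-exp state for logcumsumexp. For intuition, here is a minimal NumPy sketch of both (host-side only, with illustrative names — this is not the kernel code); np.logaddexp.accumulate is the natural reference for the end-to-end op.

import numpy as np

def kahan_prefix(block_aggregates):
    """Kahan-compensated running total, mirroring BlockPrefixCallbackOp:
    each call yields the prefix that seeds the next block-wide scan."""
    total, comp = 0.0, 0.0
    for agg in block_aggregates:
        yield total                  # old_prefix handed to the block scan
        y = agg - comp               # compensated increment
        t = total + y
        comp = (t - total) - y       # low-order bits lost when forming t
        total = t

def logaddexp_prefix(block_aggregates):
    """Online-rescaled log-sum-exp prefix, mirroring the LogAddExp
    specialization: the state (m, s, c) represents the prefix m + log(s)."""
    m, s, c = -np.inf, 0.0, 0.0
    for agg in block_aggregates:
        if s == 0.0:                 # first call seeds the state
            yield np.finfo(float).min
            m, s, c = agg, 1.0, 0.0
            continue
        yield m + np.log(s)
        m_new = max(m, agg)
        if m_new > m:                # rescale the sum to the new maximum
            scale = np.exp(m - m_new)
            s *= scale
            c *= scale
        y = np.exp(agg - m_new) - c  # Kahan step on the scaled sum
        t = s + y
        c = (t - s) - y
        s, m = t, m_new

In the kernel, each agg is the block-wide aggregate produced by cub::BlockScan, and the yielded prefix plays the role of old_prefix returned to seed the next tile's scan.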
9 changes: 7 additions & 2 deletions test/legacy_test/test_logcumsumexp_op.py
@@ -77,8 +77,13 @@ def np_logcumsumexp_grad(
exclusive: bool = False,
):
out = np_logcumsumexp(x, axis, flatten, reverse, exclusive)
log_grad_positive = np.where(dout > 0, np.log(dout), np.finfo(x.dtype).min)
log_grad_negative = np.where(dout < 0, np.log(-dout), np.finfo(x.dtype).min)
dout = np.asarray(dout)
pos_mask = dout > 0
neg_mask = dout < 0
log_grad_positive = np.full_like(dout, np.finfo(x.dtype).min)
log_grad_negative = np.full_like(dout, np.finfo(x.dtype).min)
log_grad_positive[pos_mask] = np.log(dout[pos_mask])
log_grad_negative[neg_mask] = np.log(-dout[neg_mask])

output_pos = np.exp(
np_logcumsumexp(
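As a quick end-to-end sanity check (assuming the usual paddle.cumsum / paddle.logcumsumexp signatures; the tolerance here is an assumption, not a test from this PR), one might compare both ops against NumPy references in float64:

import numpy as np
import paddle

x = np.random.randn(4, 4096).astype("float64")
t = paddle.to_tensor(x)

# cumsum against NumPy's cumulative sum
np.testing.assert_allclose(
    paddle.cumsum(t, axis=-1).numpy(), np.cumsum(x, axis=-1), rtol=1e-10
)

# logcumsumexp against np.logaddexp.accumulate
np.testing.assert_allclose(
    paddle.logcumsumexp(t, axis=-1).numpy(),
    np.logaddexp.accumulate(x, axis=-1),
    rtol=1e-10,
)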