
Commit 98b4876

lshpku authored and wanghuancoder committed
[PHI] Align linalg.solve kernel with torch (PaddlePaddle#72608)
1 parent 752eb6e commit 98b4876

File tree: 1 file changed (+175, -50 lines)

paddle/phi/kernels/funcs/matrix_solve.cu

Lines changed: 175 additions & 50 deletions
@@ -17,10 +17,131 @@ limitations under the License. */
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/scatter.cu.h"
 
 namespace phi {
 namespace funcs {
 
+#ifndef PADDLE_WITH_HIP
+/**
+ * Transform pivot array to permutation by swapping perm[i] and perm[pivot[i]]
+ * from 0 to n-1, where pivot and perm have shape [batch_size, n].
+ * Example:
+ *   Input pivot = [[6, 7, 4, 5, 5, 7, 8, 8]]
+ *   Output perm = [[5, 6, 3, 4, 2, 1, 7, 0]]
+ */
+__global__ void UnpackPivot(const int* __restrict__ pivot,
+                            int* __restrict__ perm,
+                            int64_t batch_size,
+                            int64_t n) {
+  constexpr int warp_size = 32;
+  int warps_per_block = blockDim.x / warp_size;
+  int warp_id = threadIdx.x / warp_size;
+  int warp_offset = threadIdx.x % warp_size;
+  int64_t offset = static_cast<int64_t>(blockIdx.x) * warps_per_block + warp_id;
+  int64_t stride = static_cast<int64_t>(gridDim.x) * warps_per_block;
+
+  for (; offset < batch_size; offset += stride) {
+    // init perm[*, n] with 0...n-1
+    for (int64_t i = warp_offset; i < n; i += warp_size) {
+      perm[offset * n + i] = offset * n + i;
+    }
+    __syncwarp();
+
+    // Since the swapping makes entirely discrete access, we only use the first
+    // thread in each warp to avoid warp divergence.
+    if (warp_offset > 0) continue;
+
+    // Swap perm[i] and perm[pivot[i]] for i in 0...n-1
+    for (int64_t i = offset * n; i < offset * n + n; ++i) {
+      int64_t j = pivot[i] - 1 + offset * n;  // cublas use 1-index
+      int tmp = perm[i];
+      perm[i] = perm[j];
+      perm[j] = tmp;
+    }
+  }
+}
+
+/**
+ * Eliminate the L and U in equation:
+ *   (U^T @ L^T @ P) @ X = B   (the U^T @ L^T @ P is stored in A)
+ * by solving the inversion of L^T and U^T respectively. The result is:
+ *   P @ X = L^T^-1 @ U^T^-1 @ B
+ * and is stored in B.
+ */
+template <typename Context, typename T>
+void SolveLU(const phi::funcs::BlasT<Context, T>& blas,
+             int m,
+             int n,
+             const T* A,
+             T* B,
+             int batch_size) {
+  constexpr T alpha = 1.0;
+  for (int64_t i = 0; i < batch_size; ++i) {
+    // Before: U^T @ L^T @ P @ X = B
+    blas.TRSM(CblasRight,
+              CblasLower,
+              CblasTrans,
+              CblasNonUnit,
+              m,
+              n,
+              alpha,
+              A + i * n * n,
+              n,
+              B + i * m * n,
+              n);
+    // After: L^T @ P @ X = U^T^-1 @ B
+    blas.TRSM(CblasRight,
+              CblasUpper,
+              CblasTrans,
+              CblasUnit,
+              m,
+              n,
+              alpha,
+              A + i * n * n,
+              n,
+              B + i * m * n,
+              n);
+    // After: P @ X = L^T^-1 @ U^T^-1 @ B
+  }
+}
+
+// Batched version of SolveLU.
+template <typename Context, typename T>
+void BatchedSolveLU(const phi::funcs::BlasT<Context, T>& blas,
+                    int m,
+                    int n,
+                    const T** A,
+                    T** B,
+                    int batch_size) {
+  constexpr T alpha = 1.0;
+  blas.BatchedTRSM(CblasRight,
+                   CblasLower,
+                   CblasTrans,
+                   CblasNonUnit,
+                   m,
+                   n,
+                   alpha,
+                   A,
+                   n,
+                   B,
+                   n,
+                   batch_size);
+  blas.BatchedTRSM(CblasRight,
+                   CblasUpper,
+                   CblasTrans,
+                   CblasUnit,
+                   m,
+                   n,
+                   alpha,
+                   A,
+                   n,
+                   B,
+                   n,
+                   batch_size);
+}
+#endif
+
 template <typename Context, typename T>
 void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
                                                 const DenseTensor& a,
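
As a quick cross-check of the example given in the UnpackPivot comment, here is a minimal host-side sketch (illustrative only, not part of this commit; UnpackPivotHost is a hypothetical name) that reproduces the documented permutation for a single batch. The device kernel above does the same swaps but adds a per-batch offset of offset * n to every index and initializes the identity in parallel across a warp.

#include <cstdio>
#include <utility>
#include <vector>

// Convert a 1-indexed LAPACK/cuBLAS pivot array into a row permutation:
// start from the identity and apply the recorded swaps in order.
std::vector<int> UnpackPivotHost(const std::vector<int>& pivot) {
  const int n = static_cast<int>(pivot.size());
  std::vector<int> perm(n);
  for (int i = 0; i < n; ++i) perm[i] = i;  // identity permutation
  for (int i = 0; i < n; ++i) {
    int j = pivot[i] - 1;  // cuBLAS pivots are 1-indexed
    std::swap(perm[i], perm[j]);
  }
  return perm;
}

int main() {
  // Example from the kernel comment above.
  std::vector<int> pivot = {6, 7, 4, 5, 5, 7, 8, 8};
  for (int p : UnpackPivotHost(pivot)) std::printf("%d ", p);
  std::printf("\n");  // prints: 5 6 3 4 2 1 7 0
  return 0;
}

Because the swaps must be applied in sequence, the kernel leaves them to the first thread of each warp, as its comment notes.
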
@@ -39,47 +160,38 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   const int a_rank = a_dims.size();
   int n = a_dims[a_rank - 1];
   int lda = n;
-  int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
+  int64_t batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
 
   const auto& b_dims = b.dims();
   const int b_rank = b_dims.size();
   int nrhs = b_dims[b_rank - 1];
-  int ldb = b_dims[b_rank - 2];
-
-  // make sure the out dims is right
-  out->Resize(b_dims);
+  int ldb = n;
 
-  context.template Alloc<T>(out);
-
-  // copy input A to a temporary tensor tmp_a,
-  // LU factorization, written back to original matrix A, so in the beginning,
-  // it's necessary to create a temporary tensor tmp_a.
+  // 1. Copy input A to a temporary tensor tmp_a for LU factorization.
   DenseTensor tmp_a(a.dtype());
   tmp_a.Resize(a.dims());
-
   context.template Alloc<T>(&tmp_a);
   phi::Copy(context, a, context.GetPlace(), false, &tmp_a);
 
-  // copy input B to a temporary tensor tmp_b, and transpose tmp_b,
-  // because cuBlas assumes column-major while Paddle uses row-majar.
-  DenseTensor tmp_b(b.type());
-  const auto& new_dims_vec = getNewDimsVec(b_dims);
-  tmp_b.Resize(common::make_ddim(new_dims_vec));
-  context.template Alloc<T>(&tmp_b);
+  // 2. Transpose B and save it in out, because cuBlas assumes column-major
+  // while Paddle uses row-majar.
+  const auto& new_b_dims = getNewDimsVec(b_dims);
+  out->Resize(common::make_ddim(new_b_dims));
+  context.template Alloc<T>(out);
   phi::funcs::TransposeNormal<Context, T> trans;
   std::vector<int> new_axis = getNewAxis(b_rank);
-  trans(context, b, &tmp_b, new_axis);
+  trans(context, b, out, new_axis);
 
   const T* a_data_in_gpu = tmp_a.data<T>();
-  const T* b_data_in_gpu = tmp_b.data<T>();
+  T* b_data_in_gpu = out->data<T>();
 
   std::vector<const T*> cpu_ptrs(batch_size * 2);
-  for (int i = 0; i < batch_size; ++i) {
+  for (int64_t i = 0; i < batch_size; ++i) {
     cpu_ptrs[i] = a_data_in_gpu + i * n * n;
     cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs;
   }
 
-  // Copy the addresses of A and tmp_b from host to device.
+  // 3. Copy the addresses of A and B from host to device.
   phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
       context.GetPlace(),
       cpu_ptrs.size() * sizeof(T*),
@@ -94,8 +206,8 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   T** gpu_tmp_b_ptrs =
       reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
 
-  // Allocate device memory for BatchedGETRF's info and pivots.
-  int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
+  // 4. Allocate device memory for BatchedGETRF's info and pivots.
+  int64_t num_ints = batch_size * (n + 1);
   phi::Allocator::AllocationPtr tmp_gpu_info_data = phi::memory_utils::Alloc(
       context.GetPlace(),
       num_ints * sizeof(int),
@@ -111,14 +223,13 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   int* gpu_pivot_ptr =
       reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;
 
-  // This function performs the LU factorization of each matrix A by the
-  // equation A = L * U. L and U are written back to original matrix A,
-  // and diagonal elements of L are discarded.
+  // 5. Performs LU factorization on A.
   blas.BatchedGETRF(n,
                     reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
                     gpu_pivot_ptr,
                     gpu_info_ptr,
                     batch_size);
+  // After: P @ A^T = L @ U
 
   // check whether BatchedGETRF is executed successfully or not
   memory_utils::Copy(phi::CPUPlace(),
@@ -139,33 +250,47 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
                           info[i]));
   }
 
-  // hold the result code from BatchedGETRS
-  int host_info = 0;
+  // 6. Solve L and U in equation Ax = B where A = U^T @ L^T @ P.
+  // The batched version is advantageous for small shapes, but has error for
+  // large shapes. In this case, we call the non-batched version for batch_size
+  // times instead.
+  // Ref: https://docs.nvidia.com/cuda/cublas/#cublas-t-trsmbatched
+  constexpr int max_batch_nrhs = 65535 * 8;  // max(gridDim.y) * 8
+  if (batch_size > 1 && nrhs <= max_batch_nrhs) {
+    BatchedSolveLU(blas,
+                   nrhs,
+                   n,
+                   reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
+                   gpu_tmp_b_ptrs,
+                   batch_size);
+  } else {
+    SolveLU(blas, nrhs, n, a_data_in_gpu, b_data_in_gpu, batch_size);
+  }
+
+  // 7. Transpose B back to row-major form.
+  DenseTensor tmp_b(b.type());
+  tmp_b.Resize(b_dims);
+  context.template Alloc<T>(&tmp_b);
+  phi::funcs::TransposeNormal<Context, T> trans2;
+  trans2(context, *out, &tmp_b, new_axis);
 
-  // to solve the equation after LU factorization
-  CBLAS_TRANSPOSE transA = CblasTrans;
-  blas.BatchedGETRS(transA,
-                    n,
-                    nrhs,
-                    reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
-                    lda,
-                    gpu_pivot_ptr,
-                    gpu_tmp_b_ptrs,
-                    ldb,
-                    &host_info,
-                    batch_size);
+  // 8. Permute B according to pivots to get the final result.
+  DenseTensor perm;
+  perm.Resize({batch_size * n});
+  context.template Alloc<int>(&perm);
 
-  // check whether BatchedGETRS is executed successfully or not
-  PADDLE_ENFORCE_EQ(host_info,
-                    0,
-                    common::errors::InvalidArgument(
-                        "The [%d]'th argument to cublas*getrsBatched had "
-                        "an illegal value.",
-                        -host_info));
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(context, batch_size * 32);
+  auto stream = context.stream();
+  UnpackPivot<<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+      gpu_pivot_ptr, perm.data<int>(), batch_size, n);
 
-  // transpose tmp_b to get the final result in row-major form.
-  phi::funcs::TransposeNormal<Context, T> trans2;
-  trans2(context, tmp_b, out, new_axis);
+  // fuse dims 0...n-2 because scatter only supports one index dim
+  tmp_b.Resize({batch_size * n, nrhs});
+  out->Resize({batch_size * n, nrhs});
+  GPUScatterAssign<T>(context, tmp_b, perm, out);
+  out->Resize(b_dims);
+  // After: X = P^T @ L^T^-1 @ U^T^-1 @ B
 
 #else
   compute_solve_eigen<Context, T>(context, a, b, out);
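
For readers following the numbered comments, the algebra behind steps 5-8, restated from the in-code comments: cuBLAS is column-major, so the row-major buffer holding A is consumed as A^T, and BatchedGETRF produces

    P A^{\top} = L U \;\Longrightarrow\; A = U^{\top} L^{\top} P,

so the system becomes

    A X = B \;\Longrightarrow\; U^{\top} L^{\top} P X = B \;\Longrightarrow\; P X = (L^{\top})^{-1} (U^{\top})^{-1} B \;\Longrightarrow\; X = P^{\top} (L^{\top})^{-1} (U^{\top})^{-1} B.

The two TRSM calls in SolveLU / BatchedSolveLU apply (U^T)^-1 and then (L^T)^-1, and UnpackPivot followed by GPUScatterAssign applies the remaining row permutation (the P^T in the last line).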
