@@ -17,10 +17,131 @@ limitations under the License. */
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
+ #include "paddle/phi/kernels/funcs/scatter.cu.h"

namespace phi {
namespace funcs {

+ #ifndef PADDLE_WITH_HIP
+ /**
+  * Transform the pivot array into a permutation by swapping perm[i] and
+  * perm[pivot[i] - 1] (pivot is 1-based) for i from 0 to n-1, where pivot and
+  * perm have shape [batch_size, n].
+  * Example:
+  *   Input pivot = [[6, 7, 4, 5, 5, 7, 8, 8]]
+  *   Output perm = [[5, 6, 3, 4, 2, 1, 7, 0]]
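+  *   (Step by step, starting from perm = [0, 1, 2, 3, 4, 5, 6, 7]:
+  *    i=0 swaps perm[0]<->perm[5], i=1 swaps perm[1]<->perm[6],
+  *    i=2 swaps perm[2]<->perm[3], i=3 swaps perm[3]<->perm[4],
+  *    i=5 swaps perm[5]<->perm[6], i=6 swaps perm[6]<->perm[7];
+  *    i=4 and i=7 swap an element with itself.)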
+  */
+ __global__ void UnpackPivot(const int* __restrict__ pivot,
+                             int* __restrict__ perm,
+                             int64_t batch_size,
+                             int64_t n) {
+   constexpr int warp_size = 32;
+   int warps_per_block = blockDim.x / warp_size;
+   int warp_id = threadIdx.x / warp_size;
+   int warp_offset = threadIdx.x % warp_size;
+   int64_t offset = static_cast<int64_t>(blockIdx.x) * warps_per_block + warp_id;
+   int64_t stride = static_cast<int64_t>(gridDim.x) * warps_per_block;
+
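+   // Grid-stride loop over the batch: each 32-thread warp handles one batch
+   // entry.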
+   for (; offset < batch_size; offset += stride) {
+     // Init perm[offset, :] with the global row indices
+     // offset * n ... offset * n + n - 1.
+     for (int64_t i = warp_offset; i < n; i += warp_size) {
+       perm[offset * n + i] = offset * n + i;
+     }
+     __syncwarp();
+
+     // The swapping is sequential and makes entirely scattered accesses, so
+     // only the first thread in each warp performs it to avoid warp divergence.
+     if (warp_offset > 0) continue;
+
+     // Swap perm[i] and perm[pivot[i] - 1] for i in 0...n-1.
+     for (int64_t i = offset * n; i < offset * n + n; ++i) {
+       int64_t j = pivot[i] - 1 + offset * n;  // cuBLAS uses 1-based pivots
+       int tmp = perm[i];
+       perm[i] = perm[j];
+       perm[j] = tmp;
+     }
+   }
+ }
+
+ /**
+  * Eliminate L and U in the equation
+  *   (U^T @ L^T @ P) @ X = B   (U^T @ L^T @ P is what is stored in A)
+  * by two triangular solves against U^T and L^T. The result,
+  *   P @ X = L^T^-1 @ U^T^-1 @ B,
+  * is stored in B.
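+  *
+  * Why A = U^T @ L^T @ P: cuBLAS reads the row-major A as its column-major
+  * transpose, so BatchedGETRF factors P @ A^T = L @ U, which gives
+  * A = (P^T @ L @ U)^T = U^T @ L^T @ P.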
+  */
+ template <typename Context, typename T>
+ void SolveLU(const phi::funcs::BlasT<Context, T>& blas,
+              int m,
+              int n,
+              const T* A,
+              T* B,
+              int batch_size) {
+   constexpr T alpha = 1.0;
+   for (int64_t i = 0; i < batch_size; ++i) {
+     // Before: U^T @ L^T @ P @ X = B
+     blas.TRSM(CblasRight,
+               CblasLower,
+               CblasTrans,
+               CblasNonUnit,
+               m,
+               n,
+               alpha,
+               A + i * n * n,
+               n,
+               B + i * m * n,
+               n);
+     // After: L^T @ P @ X = U^T^-1 @ B
+     blas.TRSM(CblasRight,
+               CblasUpper,
+               CblasTrans,
+               CblasUnit,
+               m,
+               n,
+               alpha,
+               A + i * n * n,
+               n,
+               B + i * m * n,
+               n);
+     // After: P @ X = L^T^-1 @ U^T^-1 @ B
+   }
+ }
+
+ // Batched version of SolveLU.
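+ // It performs the same two triangular solves through BatchedTRSM, taking
+ // arrays of per-matrix device pointers instead of a strided base pointer.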
+ template <typename Context, typename T>
+ void BatchedSolveLU(const phi::funcs::BlasT<Context, T>& blas,
+                     int m,
+                     int n,
+                     const T** A,
+                     T** B,
+                     int batch_size) {
+   constexpr T alpha = 1.0;
+   blas.BatchedTRSM(CblasRight,
+                    CblasLower,
+                    CblasTrans,
+                    CblasNonUnit,
+                    m,
+                    n,
+                    alpha,
+                    A,
+                    n,
+                    B,
+                    n,
+                    batch_size);
+   blas.BatchedTRSM(CblasRight,
+                    CblasUpper,
+                    CblasTrans,
+                    CblasUnit,
+                    m,
+                    n,
+                    alpha,
+                    A,
+                    n,
+                    B,
+                    n,
+                    batch_size);
+ }
+ #endif
+

template <typename Context, typename T>
void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
                                                const DenseTensor& a,
@@ -39,47 +160,38 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
  const int a_rank = a_dims.size();
  int n = a_dims[a_rank - 1];
  int lda = n;
-   int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
+   int64_t batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;

  const auto& b_dims = b.dims();
  const int b_rank = b_dims.size();
  int nrhs = b_dims[b_rank - 1];
-   int ldb = b_dims[b_rank - 2];
-
-   // make sure the out dims is right
-   out->Resize(b_dims);
+   int ldb = n;

-   context.template Alloc<T>(out);
-
-   // copy input A to a temporary tensor tmp_a,
-   // LU factorization, written back to original matrix A, so in the beginning,
-   // it's necessary to create a temporary tensor tmp_a.
+   // 1. Copy input A to a temporary tensor tmp_a for LU factorization.
  DenseTensor tmp_a(a.dtype());
  tmp_a.Resize(a.dims());
-
  context.template Alloc<T>(&tmp_a);
  phi::Copy(context, a, context.GetPlace(), false, &tmp_a);

-   // copy input B to a temporary tensor tmp_b, and transpose tmp_b,
-   // because cuBlas assumes column-major while Paddle uses row-majar.
-   DenseTensor tmp_b(b.type());
-   const auto& new_dims_vec = getNewDimsVec(b_dims);
-   tmp_b.Resize(common::make_ddim(new_dims_vec));
-   context.template Alloc<T>(&tmp_b);
+   // 2. Transpose B and save it in out, because cuBLAS assumes column-major
+   //    while Paddle uses row-major.
+   const auto& new_b_dims = getNewDimsVec(b_dims);
+   out->Resize(common::make_ddim(new_b_dims));
+   context.template Alloc<T>(out);
  phi::funcs::TransposeNormal<Context, T> trans;
  std::vector<int> new_axis = getNewAxis(b_rank);
-   trans(context, b, &tmp_b, new_axis);
+   trans(context, b, out, new_axis);

  const T* a_data_in_gpu = tmp_a.data<T>();
-   const T* b_data_in_gpu = tmp_b.data<T>();
+   T* b_data_in_gpu = out->data<T>();

  std::vector<const T*> cpu_ptrs(batch_size * 2);
-   for (int i = 0; i < batch_size; ++i) {
+   for (int64_t i = 0; i < batch_size; ++i) {
    cpu_ptrs[i] = a_data_in_gpu + i * n * n;
    cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs;
  }

-   // Copy the addresses of A and tmp_b from host to device.
+   // 3. Copy the addresses of A and B from host to device.
  phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
      context.GetPlace(),
      cpu_ptrs.size() * sizeof(T*),
@@ -94,8 +206,8 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
  T** gpu_tmp_b_ptrs =
      reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;

-   // Allocate device memory for BatchedGETRF's info and pivots.
-   int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
+   // 4. Allocate device memory for BatchedGETRF's info and pivots.
+   int64_t num_ints = batch_size * (n + 1);
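+   // (batch_size ints for the info array plus batch_size * n ints for the
+   // pivots.)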
  phi::Allocator::AllocationPtr tmp_gpu_info_data = phi::memory_utils::Alloc(
      context.GetPlace(),
      num_ints * sizeof(int),
@@ -111,14 +223,13 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
  int* gpu_pivot_ptr =
      reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;

-   // This function performs the LU factorization of each matrix A by the
-   // equation A = L * U. L and U are written back to original matrix A,
-   // and diagonal elements of L are discarded.
+   // 5. Perform LU factorization on A.
  blas.BatchedGETRF(n,
                    reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
                    gpu_pivot_ptr,
                    gpu_info_ptr,
                    batch_size);
+   // After: P @ A^T = L @ U

  // check whether BatchedGETRF is executed successfully or not
  memory_utils::Copy(phi::CPUPlace(),
@@ -139,33 +250,47 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
                          info[i]));
  }

-   // hold the result code from BatchedGETRS
-   int host_info = 0;
+   // 6. Solve out L and U in the equation A @ X = B, where A = U^T @ L^T @ P.
+   //    The batched version is advantageous for small shapes but produces
+   //    errors for large shapes; in that case we call the non-batched version
+   //    batch_size times instead.
+   //    Ref: https://docs.nvidia.com/cuda/cublas/#cublas-t-trsmbatched
+   constexpr int max_batch_nrhs = 65535 * 8;  // max(gridDim.y) * 8
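+   // (65535 is the CUDA limit on gridDim.y; the extra factor of 8 is
+   // presumably a heuristic margin.)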
+   if (batch_size > 1 && nrhs <= max_batch_nrhs) {
+     BatchedSolveLU(blas,
+                    nrhs,
+                    n,
+                    reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
+                    gpu_tmp_b_ptrs,
+                    batch_size);
+   } else {
+     SolveLU(blas, nrhs, n, a_data_in_gpu, b_data_in_gpu, batch_size);
+   }
+
+   // 7. Transpose B back to row-major form.
+   DenseTensor tmp_b(b.type());
+   tmp_b.Resize(b_dims);
+   context.template Alloc<T>(&tmp_b);
+   phi::funcs::TransposeNormal<Context, T> trans2;
+   trans2(context, *out, &tmp_b, new_axis);

-   // to solve the equation after LU factorization
-   CBLAS_TRANSPOSE transA = CblasTrans;
-   blas.BatchedGETRS(transA,
-                     n,
-                     nrhs,
-                     reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
-                     lda,
-                     gpu_pivot_ptr,
-                     gpu_tmp_b_ptrs,
-                     ldb,
-                     &host_info,
-                     batch_size);
+   // 8. Permute B according to pivots to get the final result.
+   DenseTensor perm;
+   perm.Resize({batch_size * n});
+   context.template Alloc<int>(&perm);

-   // check whether BatchedGETRS is executed successfully or not
-   PADDLE_ENFORCE_EQ(host_info,
-                     0,
-                     common::errors::InvalidArgument(
-                         "The [%d]'th argument to cublas*getrsBatched had "
-                         "an illegal value.",
-                         -host_info));
+   auto config =
+       phi::backends::gpu::GetGpuLaunchConfig1D(context, batch_size * 32);
+   auto stream = context.stream();
+   UnpackPivot<<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+       gpu_pivot_ptr, perm.data<int>(), batch_size, n);
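+   // (UnpackPivot uses one 32-thread warp per batch entry, hence the request
+   // for batch_size * 32 threads.)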

-   // transpose tmp_b to get the final result in row-major form.
-   phi::funcs::TransposeNormal<Context, T> trans2;
-   trans2(context, tmp_b, out, new_axis);
+   // Fuse all dims except the last, because scatter supports only one index
+   // dim.
+   tmp_b.Resize({batch_size * n, nrhs});
+   out->Resize({batch_size * n, nrhs});
+   GPUScatterAssign<T>(context, tmp_b, perm, out);
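+   // The scatter writes out[perm[i], :] = tmp_b[i, :] for each fused row i,
+   // i.e. it applies the row permutation P^T.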
+   out->Resize(b_dims);
+   // After: X = P^T @ L^T^-1 @ U^T^-1 @ B

#else
  compute_solve_eigen<Context, T>(context, a, b, out);