16
16
#include " ../include/bestla_packq_impl.hpp"
17
17
18
18
namespace woq {
19
- template <class GemmCore , BTLA_ISA ISA>
19
+
20
+ template <class proB >
20
21
void execute_qpack (repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
21
- using proB = bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>;
22
22
static proB ker;
23
- auto qpackw = ker.createStorage (ctx->n , ctx->k , p->blocksize , wei2bestladt_map.at (p->weight_type ),
24
- scale2bestladt_map.at (p->scale_type ), BTLA_DTYPE::BF16, p->asym );
23
+ using WType = typename proB::StorageWeight;
24
+ WType qpackw (0 );
25
+ if constexpr (std::is_same_v<WType, bestla::storage::gemm::StorageWeightKBlockNInteger>) {
26
+ qpackw = ker.createStorage (ctx->n , ctx->k , p->blocksize , wei2bestladt_map.at (p->weight_type ),
27
+ scale2bestladt_map.at (p->scale_type ), BTLA_DTYPE::BF16, p->asym );
28
+ } else {
29
+ qpackw = ker.createStorage (ctx->n , ctx->k , p->blocksize , wei2bestladt_map.at (p->weight_type ),
30
+ scale2bestladt_map.at (p->scale_type ));
31
+ }
25
32
if (p->enable_act_shuffle ) ker.enableShuffle (&qpackw);
26
33
ctx->packw_size = qpackw.mSize ;
27
34
if (task == WOQ_GET_PACKW_SIZE) return ;
@@ -33,6 +40,20 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
33
40
p->asym ? ctx->zp ->data_ptr <int8_t >() : nullptr , &qpackw, dispatcher_utils::qbits_threading::get ());
34
41
}
35
42
43
+ template <class GemmCore , BTLA_ISA ISA>
44
+ void parse_prob (repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
45
+ if (p->weight_type == " int8" || p->weight_type == " int4_clip" || p->weight_type == " int3_clip" ||
46
+ p->weight_type == " int2_clip" ) {
47
+ return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>>(p, ctx, task);
48
+ }
49
+ if (p->weight_type == " nf4" || p->weight_type == " fp4_e2m1_bnb" || p->weight_type == " fp4_e2m1" ) {
50
+ TORCH_CHECK (!p->asym , " Qbits: float-weight unsupports asym quantization." );
51
+ return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNFloat<GemmCore, ISA>>(p, ctx, task);
52
+ }
53
+ TORCH_CHECK (false , " Qbits: unsupported bestla packq config, compute_type: " + p->compute_type +
54
+ " weight_type: " + p->weight_type );
55
+ }
56
+
36
57
std::string get_dtype_str (BTLA_DTYPE dtype) {
37
58
switch (dtype) {
38
59
case BTLA_DTYPE::F32:
@@ -183,40 +204,38 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
183
204
}
184
205
185
206
void bestla_packq (repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
186
- // TODO(zhe): elegant impl.
187
- TORCH_CHECK (p->weight_type == " int8" || p->weight_type == " int4_clip" || p->weight_type == " int3_clip" ||
188
- p->weight_type == " int2_clip" ,
189
- " Qbits: only support Integer WOQ in PACKQ" );
190
-
191
207
if (p->compute_type == " int8" ) {
208
+ TORCH_CHECK (p->weight_type == " int8" || p->weight_type == " int4_clip" || p->weight_type == " int3_clip" ||
209
+ p->weight_type == " int2_clip" ,
210
+ " Qbits: only support Integer weight-type with int8 compute-type" );
192
211
if (dispatcher_utils::check_amx () && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64 , 16 >::KTILE == 0 ) {
193
- return execute_qpack <bestla::gemm::ICoreRowNAmxint8KBlock<64 , 16 >, BTLA_ISA::AMX_INT8>(p, ctx, task);
212
+ return parse_prob <bestla::gemm::ICoreRowNAmxint8KBlock<64 , 16 >, BTLA_ISA::AMX_INT8>(p, ctx, task);
194
213
}
195
214
if (dispatcher_utils::check_avx512_vnni () &&
196
215
p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48 , 4 >::KTILE == 0 ) {
197
- return execute_qpack <bestla::gemm::ICoreRowNAvx512vnniKBlock<48 , 4 >, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
216
+ return parse_prob <bestla::gemm::ICoreRowNAvx512vnniKBlock<48 , 4 >, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
198
217
}
199
218
if (dispatcher_utils::check_avx_vnni () && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24 , 2 >::KTILE == 0 ) {
200
- return execute_qpack <bestla::gemm::ICoreRowNAvxvnniKBlock<24 , 2 >, BTLA_ISA::AVX_VNNI>(p, ctx, task);
219
+ return parse_prob <bestla::gemm::ICoreRowNAvxvnniKBlock<24 , 2 >, BTLA_ISA::AVX_VNNI>(p, ctx, task);
201
220
}
202
221
if (dispatcher_utils::check_avx2 () && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24 , 2 >::KTILE == 0 ) {
203
- return execute_qpack <bestla::gemm::ICoreRowNAvx2vnniKBlock<24 , 2 >, BTLA_ISA::AVX2>(p, ctx, task);
222
+ return parse_prob <bestla::gemm::ICoreRowNAvx2vnniKBlock<24 , 2 >, BTLA_ISA::AVX2>(p, ctx, task);
204
223
}
205
224
TORCH_CHECK (false , " Qbits: Illegal config in int8 compute_type, blocksize:" , p->blocksize ,
206
225
" , ISA support avx2:" , dispatcher_utils::check_avx2 ());
207
226
}
208
227
if (p->compute_type == " fp32" ) {
209
228
if (dispatcher_utils::check_avx512f ()) {
210
- return execute_qpack <bestla::gemm::SCoreRowNAvx512f<48 , 8 >, BTLA_ISA::AVX512F>(p, ctx, task);
229
+ return parse_prob <bestla::gemm::SCoreRowNAvx512f<48 , 8 >, BTLA_ISA::AVX512F>(p, ctx, task);
211
230
}
212
231
if (dispatcher_utils::check_avx2 ()) {
213
- return execute_qpack <bestla::gemm::SCoreRowNAvx2<24 , 4 >, BTLA_ISA::AVX2>(p, ctx, task);
232
+ return parse_prob <bestla::gemm::SCoreRowNAvx2<24 , 4 >, BTLA_ISA::AVX2>(p, ctx, task);
214
233
}
215
234
TORCH_CHECK (false , " Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32" );
216
235
}
217
236
if (p->compute_type == " bf16" ) {
218
237
if (dispatcher_utils::check_amx ()) {
219
- return execute_qpack <bestla::gemm::HCoreRowNAmxbf16<64 , 16 >, BTLA_ISA::AMX_BF16>(p, ctx, task);
238
+ return parse_prob <bestla::gemm::HCoreRowNAmxbf16<64 , 16 >, BTLA_ISA::AMX_BF16>(p, ctx, task);
220
239
}
221
240
TORCH_CHECK (false , " Qbits: device ISA must support AMX-BF16 when compute_type==bf16" );
222
241
}
0 commit comments