diff --git a/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/include/bestla_packq_impl.hpp b/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/include/bestla_packq_impl.hpp
index 88b16ac36f4..1956414528a 100644
--- a/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/include/bestla_packq_impl.hpp
+++ b/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/include/bestla_packq_impl.hpp
@@ -30,6 +30,6 @@ enum PACKW_ACQUIRE_TYPE {
   IS_ASYM
 };
 
-void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx);
+void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task);
 torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T);
 }  // namespace woq
diff --git a/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/include/bestla_weightonly_dispatcher.hpp b/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/include/bestla_weightonly_dispatcher.hpp
index 7f8cba2b23e..fd8b25f68b2 100644
--- a/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/include/bestla_weightonly_dispatcher.hpp
+++ b/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/include/bestla_weightonly_dispatcher.hpp
@@ -25,6 +25,8 @@ enum WOQ_TASK {
   WOQ_QUANTIZE,
   WOQ_DEQUANTIZE,
   WOQ_LINEAR,
+  WOQ_REPACK,
+  WOQ_GET_PACKW_SIZE,
 };
 
 struct woq_param_base {
@@ -47,6 +49,7 @@ struct repack_quantized_weight_param : public woq_param_base {
 struct repack_quantized_weight_ctx {
   torch::Tensor *qweight, *scale, *zp, *g_idx, *output;
   int n, k;
+  size_t packw_size;
 };
 
 struct woq_runtime_ctx {
diff --git a/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/src/bestla_packq_impl.cpp b/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/src/bestla_packq_impl.cpp
index 0b28a727fbd..9ae684c9790 100644
--- a/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/src/bestla_packq_impl.cpp
+++ b/intel_extension_for_transformers/transformers/llm/operator/csrc/dispatcher/src/bestla_packq_impl.cpp
@@ -3,12 +3,14 @@ namespace woq {
 template <class GemmCore, BTLA_ISA ISA>
-void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx) {
+void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
   using proB = bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>;
   static proB ker;
   auto qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map[p->weight_type],
                                   scale2bestladt_map[p->scale_type], BTLA_DTYPE::BF16, p->asym);
   if (p->enable_act_shuffle) ker.enableShuffle(&qpackw);
+  ctx->packw_size = qpackw.mSize;
+  if (task == WOQ_GET_PACKW_SIZE) return;
   *(ctx->output) = torch::empty(qpackw.mSize, torch::kInt8);
   qpackw.assign(ctx->output->data_ptr<int8_t>());
   if (p->enable_act_shuffle)
@@ -106,7 +108,8 @@ void bestla_2dcpy_tensor(int row, int col, int ld_src, torch::Tensor& dst, void*
   dst = torch::empty({row, col}, get_torch_dtype(dtype));
   auto dt_size = get_sizeof_bestla_dtype(dtype);
   for (int i = 0; i < row; i++) {
-    memcpy(reinterpret_cast<char*>(dst.data_ptr()) + i * col * dt_size, reinterpret_cast<char*>(src) + i * ld_src * dt_size, col * dt_size);
+    memcpy(reinterpret_cast<char*>(dst.data_ptr()) + i * col * dt_size,
+           reinterpret_cast<char*>(src) + i * ld_src * dt_size, col * dt_size);
   }
 }
 
@@ -164,36 +167,50 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
   return output;
 }
 
-void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx) {
+void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
   TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange",
               "Qbits: only support Integer WOQ in PACKQ");
+  // NTILE & compute-dtype determine the padsize.
+  // in qbits:
+  // avx_vnni/avx512f_vnni/amx-int8 NTILE==48, compute-dtype=int8;
+  // avx2/avx512f NTILE==48, compute-dtype=fp32;
+  // amx-bf16 NTILE==64, compute-dtype=bf16.
+  if (task == WOQ_GET_PACKW_SIZE) {
+    if (p->compute_type == "int8")
+      return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
+    if (p->compute_type == "fp32")
+      return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
+    if (p->compute_type == "bf16")
+      return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
+  }
+
   if (p->compute_type == "int8") {
     if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx);
+      return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx512_vnni() &&
         p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx);
+      return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>, BTLA_ISA::AVX_VNNI>(p, ctx);
+      return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
     }
     TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
                 ", ISA support vnni:", dispatcher_utils::check_avx_vnni());
   }
   if (p->compute_type == "fp32") {
     if (dispatcher_utils::check_avx512f()) {
-      return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx);
+      return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx2()) {
-      return execute_qpack<bestla::gemm::SCoreRowNAvx2<48, 2>, BTLA_ISA::AVX2>(p, ctx);
+      return execute_qpack<bestla::gemm::SCoreRowNAvx2<48, 2>, BTLA_ISA::AVX2>(p, ctx, task);
     }
     TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
   }
   if (p->compute_type == "bf16") {
     if (dispatcher_utils::check_amx()) {
-      return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx);
+      return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
     }
     TORCH_CHECK(false, "Qbits: device ISA must support AMX-BF16 when compute_type==bf16");
   }
diff --git a/intel_extension_for_transformers/transformers/llm/operator/csrc/qbits.cpp b/intel_extension_for_transformers/transformers/llm/operator/csrc/qbits.cpp
index 00f8449a7ba..1692dbadbee 100755
--- a/intel_extension_for_transformers/transformers/llm/operator/csrc/qbits.cpp
+++ b/intel_extension_for_transformers/transformers/llm/operator/csrc/qbits.cpp
@@ -58,26 +58,38 @@ static void inline init_woq_config_param(woq::woq_config_param* p, woq::woq_runt
   }
 }
 
-static torch::Tensor repack_quantized_weight(const torch::Tensor& qweight, const torch::Tensor& scale, const torch::Tensor& zp,
-                                             const torch::Tensor& g_idx, const std::string& weight_type,
-                                             const std::string& scale_type, const std::string& compute_type, bool asym,
-                                             int64_t blocksize) {
+static torch::Tensor repack_quantized_weight(const torch::Tensor& qweight, const torch::Tensor& scale,
+                                             const torch::Tensor& zp, const torch::Tensor& g_idx,
+                                             const std::string& weight_type, const std::string& scale_type,
+                                             const std::string& compute_type, bool asym, int64_t blocksize) {
   torch::Tensor output;
-  woq::repack_quantized_weight_param p{compute_type, weight_type, scale_type, asym, static_cast<int>(blocksize), g_idx.numel() != 0};
+  woq::repack_quantized_weight_param p{compute_type, weight_type, scale_type, asym, static_cast<int>(blocksize),
+                                       g_idx.numel() != 0};
   woq::repack_quantized_weight_ctx ctx{const_cast<torch::Tensor*>(&qweight),
-                                   const_cast<torch::Tensor*>(&scale),
-                                   const_cast<torch::Tensor*>(&zp),
-                                   const_cast<torch::Tensor*>(&g_idx),
-                                   &output,
-                                   static_cast<int>(qweight.sizes()[1]),
-                                   static_cast<int>(qweight.sizes()[0])};
-  woq::bestla_packq(&p, &ctx);
+                                       const_cast<torch::Tensor*>(&scale),
+                                       const_cast<torch::Tensor*>(&zp),
+                                       const_cast<torch::Tensor*>(&g_idx),
+                                       &output,
+                                       static_cast<int>(qweight.sizes()[1]),
+                                       static_cast<int>(qweight.sizes()[0])};
+  woq::bestla_packq(&p, &ctx, woq::WOQ_REPACK);
   return output;
 }
 
+static size_t get_packed_weight_size(int k, int n, const std::string& weight_type, const std::string& scale_type,
+                                     const std::string& compute_type, bool asym, int64_t blocksize, bool act_shuf) {
+  woq::repack_quantized_weight_param p{compute_type, weight_type, scale_type, asym, static_cast<int>(blocksize),
+                                       act_shuf};
+  woq::repack_quantized_weight_ctx ctx;
+  ctx.n = n;
+  ctx.k = k;
+  woq::bestla_packq(&p, &ctx, woq::WOQ_GET_PACKW_SIZE);
+  return ctx.packw_size;
+}
+
 static torch::Tensor quantize_to_packed_weight(const torch::Tensor& fp32_weight, bool transpose, int64_t blocksize,
-                                           const std::string& compute_type, const std::string& weight_type,
-                                           const std::string& scale_type, bool asym) {
+                                               const std::string& compute_type, const std::string& weight_type,
+                                               const std::string& scale_type, bool asym) {
   torch::Tensor output;
   woq::woq_config_param p;
   woq::woq_runtime_ctx ctx{nullptr, const_cast<torch::Tensor*>(&fp32_weight), nullptr, &output, transpose};
@@ -87,9 +99,9 @@ static torch::Tensor quantize_to_packed_weight(const torch::Tensor& fp32_weight,
   return output;
 }
 
-static void dequantize_packed_weight(const torch::Tensor& compressed_weight, torch::Tensor& dequantize_weight, bool transpose,
-                                     const std::string& compute_type, const std::string& weight_type,
-                                     const std::string& scale_type) {
+static void dequantize_packed_weight(const torch::Tensor& compressed_weight, torch::Tensor& dequantize_weight,
+                                     bool transpose, const std::string& compute_type, const std::string& weight_type,
+                                     const std::string& scale_type) {
   woq::woq_config_param p;
   woq::woq_runtime_ctx ctx{nullptr, const_cast<torch::Tensor*>(&compressed_weight), nullptr, &dequantize_weight,
                            transpose};
@@ -163,6 +175,7 @@ PYBIND11_MODULE(qbits, m) {
   m.def("woq_linear", &woq_linear);
   m.def("dequantize_packed_weight", &dequantize_packed_weight);
   m.def("repack_quantized_weight", &repack_quantized_weight);
+  m.def("get_packed_weight_size", &get_packed_weight_size);
   m.def("set_woq_workspace", &set_woq_workspace);
   m.def("matmul", &bestlaop_gemm);
   m.def("acquire_packed_weight_info", &acquire_packed_weight_info);
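Not part of the patch: a minimal Python sketch of how the new `get_packed_weight_size` binding could be exercised once the qbits extension is built. The import path and the concrete shape/config values are assumptions; only the argument order follows the pybind signature added above.

```python
# Hedged usage sketch: the import path below is an assumption, adjust to however
# the built qbits extension is exposed in your environment.
from intel_extension_for_transformers import qbits  # hypothetical import path

k, n, blocksize = 4096, 4096, 128

# New binding from this patch: dispatches bestla_packq with WOQ_GET_PACKW_SIZE,
# so only createStorage() runs and ctx.packw_size (qpackw.mSize) is reported;
# no packed-weight tensor is allocated.
nbytes = qbits.get_packed_weight_size(k, n,
                                      "int4_clip",  # weight_type
                                      "fp32",       # scale_type
                                      "int8",       # compute_type
                                      False,        # asym
                                      blocksize,
                                      False)        # act_shuf (g_idx shuffle)
print(f"packed int8 buffer for {k}x{n}: {nbytes} bytes")

# For the same config, this value should equal the numel of the int8 tensor
# returned by repack_quantized_weight, since both paths size the buffer from
# the same qpackw.mSize.
```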