
get packed-weight size via config #1459

Merged · 1 commit · Apr 10, 2024
@@ -30,6 +30,6 @@ enum PACKW_ACQUIRE_TYPE {
   IS_ASYM
 };
 
-void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx);
+void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task);
 torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T);
 } // namespace woq

@@ -25,6 +25,8 @@ enum WOQ_TASK {
   WOQ_QUANTIZE,
   WOQ_DEQUANTIZE,
   WOQ_LINEAR,
+  WOQ_REPACK,
+  WOQ_GET_PACKW_SIZE,
 };
 
 struct woq_param_base {
@@ -47,6 +49,7 @@ struct repack_quantized_weight_param : public woq_param_base {
 struct repack_quantized_weight_ctx {
   torch::Tensor *qweight, *scale, *zp, *g_idx, *output;
   int n, k;
+  size_t packw_size;
 };
 
 struct woq_runtime_ctx {

@@ -3,12 +3,14 @@
 
 namespace woq {
 template <class GemmCore, BTLA_ISA ISA>
-void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx) {
+void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
   using proB = bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>;
   static proB ker;
   auto qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map[p->weight_type],
                                   scale2bestladt_map[p->scale_type], BTLA_DTYPE::BF16, p->asym);
   if (p->enable_act_shuffle) ker.enableShuffle(&qpackw);
+  ctx->packw_size = qpackw.mSize;
+  if (task == WOQ_GET_PACKW_SIZE) return;
   *(ctx->output) = torch::empty(qpackw.mSize, torch::kInt8);
   qpackw.assign(ctx->output->data_ptr<int8_t>());
   if (p->enable_act_shuffle)
@@ -106,7 +108,8 @@ void bestla_2dcpy_tensor(int row, int col, int ld_src, torch::Tensor& dst, void*
   dst = torch::empty({row, col}, get_torch_dtype(dtype));
   auto dt_size = get_sizeof_bestla_dtype(dtype);
   for (int i = 0; i < row; i++) {
-    memcpy(reinterpret_cast<char*>(dst.data_ptr()) + i * col * dt_size, reinterpret_cast<char*>(src) + i * ld_src * dt_size, col * dt_size);
+    memcpy(reinterpret_cast<char*>(dst.data_ptr()) + i * col * dt_size,
+           reinterpret_cast<char*>(src) + i * ld_src * dt_size, col * dt_size);
   }
 }
 
@@ -164,36 +167,50 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
   return output;
 }
 
-void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx) {
+void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
   TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange",
               "Qbits: only support Integer WOQ in PACKQ");
 
+  // NTILE & compute-dtype determine the padsize.
+  // in qbits:
+  // avx_vnni/avx512f_vnni/amx-int8 NTILE==48, compute-dtype=int8;
+  // avx2/avx512f NTILE==48, compute-dtype=fp32;
+  // amx-bf16 NTILE==64, compute-dtype=bf16.
+  if (task == WOQ_GET_PACKW_SIZE) {
+    if (p->compute_type == "int8")
+      return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
+    if (p->compute_type == "fp32")
+      return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
+    if (p->compute_type == "bf16")
+      return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
+  }
+
   if (p->compute_type == "int8") {
     if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx);
+      return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx512_vnni() &&
         p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx);
+      return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>, BTLA_ISA::AVX_VNNI>(p, ctx);
+      return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
     }
     TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
                 ", ISA support vnni:", dispatcher_utils::check_avx_vnni());
   }
   if (p->compute_type == "fp32") {
     if (dispatcher_utils::check_avx512f()) {
-      return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx);
+      return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx2()) {
-      return execute_qpack<bestla::gemm::SCoreRowNAvx2<48, 2>, BTLA_ISA::AVX2>(p, ctx);
+      return execute_qpack<bestla::gemm::SCoreRowNAvx2<48, 2>, BTLA_ISA::AVX2>(p, ctx, task);
     }
     TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
   }
   if (p->compute_type == "bf16") {
     if (dispatcher_utils::check_amx()) {
-      return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx);
+      return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
     }
     TORCH_CHECK(false, "Qbits: device ISA must support AMX-BF16 when compute_type==bf16");
   }

@@ -58,26 +58,38 @@ static void inline init_woq_config_param(woq::woq_config_param* p, woq::woq_runt
   }
 }
 
-static torch::Tensor repack_quantized_weight(const torch::Tensor& qweight, const torch::Tensor& scale, const torch::Tensor& zp,
-                                             const torch::Tensor& g_idx, const std::string& weight_type,
-                                             const std::string& scale_type, const std::string& compute_type, bool asym,
-                                             int64_t blocksize) {
+static torch::Tensor repack_quantized_weight(const torch::Tensor& qweight, const torch::Tensor& scale,
+                                             const torch::Tensor& zp, const torch::Tensor& g_idx,
+                                             const std::string& weight_type, const std::string& scale_type,
+                                             const std::string& compute_type, bool asym, int64_t blocksize) {
   torch::Tensor output;
-  woq::repack_quantized_weight_param p{compute_type, weight_type, scale_type, asym, static_cast<int>(blocksize), g_idx.numel() != 0};
+  woq::repack_quantized_weight_param p{compute_type, weight_type, scale_type, asym, static_cast<int>(blocksize),
+                                       g_idx.numel() != 0};
   woq::repack_quantized_weight_ctx ctx{const_cast<torch::Tensor*>(&qweight),
-                                       const_cast<torch::Tensor*>(&scale),
-                                       const_cast<torch::Tensor*>(&zp),
-                                       const_cast<torch::Tensor*>(&g_idx),
-                                       &output,
-                                       static_cast<int>(qweight.sizes()[1]),
-                                       static_cast<int>(qweight.sizes()[0])};
-  woq::bestla_packq(&p, &ctx);
+                                       const_cast<torch::Tensor*>(&scale),
+                                       const_cast<torch::Tensor*>(&zp),
+                                       const_cast<torch::Tensor*>(&g_idx),
+                                       &output,
+                                       static_cast<int>(qweight.sizes()[1]),
+                                       static_cast<int>(qweight.sizes()[0])};
+  woq::bestla_packq(&p, &ctx, woq::WOQ_REPACK);
   return output;
 }
 
+static size_t get_packed_weight_size(int k, int n, const std::string& weight_type, const std::string& scale_type,
+                                     const std::string& compute_type, bool asym, int64_t blocksize, bool act_shuf) {
+  woq::repack_quantized_weight_param p{compute_type, weight_type, scale_type, asym, static_cast<int>(blocksize),
+                                       act_shuf};
+  woq::repack_quantized_weight_ctx ctx;
+  ctx.n = n;
+  ctx.k = k;
+  woq::bestla_packq(&p, &ctx, woq::WOQ_GET_PACKW_SIZE);
+  return ctx.packw_size;
+}
+
 static torch::Tensor quantize_to_packed_weight(const torch::Tensor& fp32_weight, bool transpose, int64_t blocksize,
-                                               const std::string& compute_type, const std::string& weight_type,
-                                               const std::string& scale_type, bool asym) {
+                                               const std::string& compute_type, const std::string& weight_type,
+                                               const std::string& scale_type, bool asym) {
   torch::Tensor output;
   woq::woq_config_param p;
   woq::woq_runtime_ctx ctx{nullptr, const_cast<torch::Tensor*>(&fp32_weight), nullptr, &output, transpose};
@@ -87,9 +99,9 @@ static torch::Tensor quantize_to_packed_weight(const torch::Tensor& fp32_weight,
   return output;
 }
 
-static void dequantize_packed_weight(const torch::Tensor& compressed_weight, torch::Tensor& dequantize_weight, bool transpose,
-                                     const std::string& compute_type, const std::string& weight_type,
-                                     const std::string& scale_type) {
+static void dequantize_packed_weight(const torch::Tensor& compressed_weight, torch::Tensor& dequantize_weight,
+                                     bool transpose, const std::string& compute_type, const std::string& weight_type,
+                                     const std::string& scale_type) {
   woq::woq_config_param p;
   woq::woq_runtime_ctx ctx{nullptr, const_cast<torch::Tensor*>(&compressed_weight), nullptr, &dequantize_weight,
                            transpose};
@@ -163,6 +175,7 @@ PYBIND11_MODULE(qbits, m) {
   m.def("woq_linear", &woq_linear);
   m.def("dequantize_packed_weight", &dequantize_packed_weight);
   m.def("repack_quantized_weight", &repack_quantized_weight);
+  m.def("get_packed_weight_size", &get_packed_weight_size);
   m.def("set_woq_workspace", &set_woq_workspace);
   m.def("matmul", &bestlaop_gemm);
   m.def("acquire_packed_weight_info", &acquire_packed_weight_info);