This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 9e22650

remove clip postfix
1 parent: bc268c5

7 files changed: +21 -28 lines changed

docs/qbits.md

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@ import intel_extension_for_transformers.qbits as qbits
     transpose (bool): Whether to transpose the weight tensor (required for quantize_to_packed_weight with KxN weight shape).
     blocksize (int): Blocksize for weight-only quantization.
     compute_type (str): Computation type (fp32/bf16/int8). fp32 will leverage AVX2/AVX512F to compute, bf16 will be AMX_BF16, int8 will be VNNI/AMX_INT8.
-    weight_type (str): Quantization type (int8/int4_clip/int4_fullrange/nf4/fp4_e2m1).
+    weight_type (str): Quantization type (int8/int4/int3/int2/nf4/fp4_e2m1).
     scale_type (str): Scale type (fp32/bf16).
     asym (bool): Whether to use asymmetric quantization.
 
@@ -37,7 +37,7 @@ pack_weight = qbits.quantize_to_packed_weight(
     g_idx (torch.Tensor): shuffle index used by GPTQ, dtype must be int32.
     blocksize (int): Blocksize for weight-only quantization.
     compute_type (str): Computation type (fp32/bf16/int8). fp32 will leverage AVX2/AVX512F to compute, bf16 will be AMX_BF16, int8 will be VNNI/AMX_INT8.
-    weight_type (str): Quantization type (int8/int4_clip/int4_fullrange/nf4/fp4_e2m1).
+    weight_type (str): Quantization type (int8/int4/int3/int2/nf4/fp4_e2m1).
     scale_type (str): Scale type (fp32/bf16).
     asym (bool): Whether to use asymmetric quantization.
 
@@ -57,7 +57,7 @@ pack_weight = qbits.repack_quantized_weight(
     bias (torch.Tensor): Bias tensor, must be fp32, if bias is empty woq_linear will not add bias.
     output (torch.Tensor): Output tensor, support fp32/bf16, shape must be MxN.
     compute_type (str): Computation type (fp32/bf16/int8).fp32 will leverage AVX2/AVX512F to compute, bf16 will leverage AMX_BF16 to compute, int8 will leverage VNNI/AMX_INT8 to compute.
-    weight_type (str): Quantization type (int8/int4_clip/int4_fullrange/nf4/fp4_e2m1).
+    weight_type (str): Quantization type (int8/int4/int3/int2/nf4/fp4_e2m1).
     scale_type (str): Scale type (fp32/bf16).
     asym (bool): Whether to use asymmetric quantization.
 """

examples/vllm/vllm_acceleration_example.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         config = RtnConfig(compute_dtype="int8",
                            group_size=128,
                            scale_dtype="bf16",
-                           weight_dtype="int4_clip",
+                           weight_dtype="int4",
                            bits=4)
         print(config)
         prompts = [args.prompt]
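As a consumption sketch, the renamed weight_dtype flows into the same quantized-model loading path the example drives; the model id below and the exact loader call are illustrative assumptions (the script itself takes the model from its CLI arguments):

```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

config = RtnConfig(compute_dtype="int8",
                   group_size=128,
                   scale_dtype="bf16",
                   weight_dtype="int4",  # formerly "int4_clip"
                   bits=4)

# Hypothetical model id used only for illustration.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             quantization_config=config)
```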

intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp

Lines changed: 4 additions & 9 deletions
@@ -59,15 +59,10 @@ struct woq_runtime_ctx {
   bestla::storage::gemm::IWeightBase* deseries_wei;
 };
 
-static std::map<std::string, BTLA_DTYPE> wei2bestladt_map{{"int8", BTLA_DTYPE::S8},
-                                                          {"int4_clip", BTLA_DTYPE::S4_CLIP},
-                                                          {"int3_clip", BTLA_DTYPE::S3_CLIP},
-                                                          {"int2_clip", BTLA_DTYPE::S2_CLIP},
-                                                          {"nf4", BTLA_DTYPE::F4_NF4},
-                                                          {"fp4_e2m1_bnb", BTLA_DTYPE::F4_BNB},
-                                                          {"fp4_e2m1", BTLA_DTYPE::F4_E2M1},
-                                                          {"fp8_e4m3", BTLA_DTYPE::F8_E4M3},
-                                                          {"fp8_e5m2", BTLA_DTYPE::F8_E5M2}};
+static std::map<std::string, BTLA_DTYPE> wei2bestladt_map{
+    {"int8", BTLA_DTYPE::S8},          {"int4", BTLA_DTYPE::S4_CLIP},     {"int3", BTLA_DTYPE::S3_CLIP},
+    {"int2", BTLA_DTYPE::S2_CLIP},     {"nf4", BTLA_DTYPE::F4_NF4},       {"fp4_e2m1_bnb", BTLA_DTYPE::F4_BNB},
+    {"fp4_e2m1", BTLA_DTYPE::F4_E2M1}, {"fp8_e4m3", BTLA_DTYPE::F8_E4M3}, {"fp8_e5m2", BTLA_DTYPE::F8_E5M2}};
 
 static std::map<std::string, BTLA_DTYPE> scale2bestladt_map{
     {"fp32", BTLA_DTYPE::F32}, {"bf16", BTLA_DTYPE::BF16}, {"fp8_e8m0", BTLA_DTYPE::F8_E8M0}};
 
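Because the map now only recognizes the shortened names, any caller still passing the old *_clip strings has to translate them first. A hypothetical Python-side shim (not part of this commit or the library) makes the rename explicit:

```python
# Hypothetical compatibility helper: map the removed *_clip spellings to the
# names wei2bestladt_map accepts after this commit.
_LEGACY_WEIGHT_TYPES = {
    "int4_clip": "int4",
    "int3_clip": "int3",
    "int2_clip": "int2",
}

def normalize_weight_type(weight_type: str) -> str:
    """Return the post-rename spelling for a possibly legacy weight_type string."""
    return _LEGACY_WEIGHT_TYPES.get(weight_type, weight_type)

assert normalize_weight_type("int4_clip") == "int4"
assert normalize_weight_type("nf4") == "nf4"  # non-integer types are unchanged
```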

intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp

Lines changed: 7 additions & 8 deletions
@@ -42,8 +42,7 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
 
 template <class GemmCore, BTLA_ISA ISA>
 void parse_prob(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
-  if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
-      p->weight_type == "int2_clip") {
+  if (p->weight_type == "int8" || p->weight_type == "int4" || p->weight_type == "int3" || p->weight_type == "int2") {
     return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>>(p, ctx, task);
   }
   if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1") {
@@ -61,11 +60,11 @@ std::string get_dtype_str(BTLA_DTYPE dtype) {
     case BTLA_DTYPE::BF16:
       return "bf16";
     case BTLA_DTYPE::S4_CLIP:
-      return "int4_clip";
+      return "int4";
     case BTLA_DTYPE::S3_CLIP:
-      return "int3_clip";
+      return "int3";
     case BTLA_DTYPE::S2_CLIP:
-      return "int2_clip";
+      return "int2";
     case BTLA_DTYPE::F4_NF4:
       return "nf4";
     case BTLA_DTYPE::F4_E2M1:
@@ -205,9 +204,9 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
 
 void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
   if (p->compute_type == "int8") {
-    TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
-                    p->weight_type == "int2_clip",
-                "Qbits: only support Integer weight-type with int8 compute-type");
+    TORCH_CHECK(
+        p->weight_type == "int8" || p->weight_type == "int4" || p->weight_type == "int3" || p->weight_type == "int2",
+        "Qbits: only support Integer weight-type with int8 compute-type");
     if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
       return parse_prob<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
     }

intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp

Lines changed: 1 addition & 2 deletions
@@ -276,8 +276,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
 template <WOQ_TASK TASK, class GemmCore>
 void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   using namespace bestla::prologue_b::gemm;
-  if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
-      p->weight_type == "int2_clip") {
+  if (p->weight_type == "int8" || p->weight_type == "int4" || p->weight_type == "int3" || p->weight_type == "int2") {
     return parse_activation<TASK, GemmCore, WeightKBlockNInteger>(p, ctx);
   }
   if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" ||

intel_extension_for_transformers/qbits/qbits_ut/test_packq.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ class acquire_type(Enum):
 @pytest.mark.parametrize("k", [512])
 @pytest.mark.parametrize("blocksize", [128])
 @pytest.mark.parametrize("compute_type", ["fp32", "bf16", "int8"])
-@pytest.mark.parametrize("weight_type", ["int8", "int4_clip"])
+@pytest.mark.parametrize("weight_type", ["int8", "int4"])
 @pytest.mark.parametrize("scale_type", ["fp32"])
 @pytest.mark.parametrize("asym", [True, False])
 def test(m, k, n, weight_type, scale_type, compute_type, asym, blocksize, dump_tensor=False):

intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py

Lines changed: 4 additions & 4 deletions
@@ -17,14 +17,14 @@
 
 from ut_utils import *
 
-cmpt_configs = {"int8": {"int8", "bf16", "fp32"}, "int4_clip": {"int8", "fp32", "bf16"}, "int3_clip": {"int8", "fp32", "bf16"}, "int2_clip": {"int8", "fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
+cmpt_configs = {"int8": {"int8", "bf16", "fp32"}, "int4": {"int8", "fp32", "bf16"}, "int3": {"int8", "fp32", "bf16"}, "int2": {"int8", "fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
                 "fp8_e5m2": {"fp32", "bf16"}, "fp8_e4m3": {"fp32", "bf16"}
                 }
 
-scale_configs = {"int8": {"fp32", "bf16"}, "int4_clip": {"fp32", "bf16"}, "int3_clip": {"fp32", "bf16"}, "int2_clip": {"fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
+scale_configs = {"int8": {"fp32", "bf16"}, "int4": {"fp32", "bf16"}, "int3": {"fp32", "bf16"}, "int2": {"fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
                  "fp8_e5m2": {"fp32", "fp8_e8m0"}, "fp8_e4m3": {"fp32", "fp8_e8m0"}}
 
-asym_configs = {"int8", "int4_clip", "int3_clip", "int2_clip"}
+asym_configs = {"int8", "int4", "int3", "int2"}
 
 
 @capture_args
@@ -33,7 +33,7 @@
 @pytest.mark.parametrize("k", [512])
 @pytest.mark.parametrize("blocksize", [128, -1])
 @pytest.mark.parametrize("compute_type", ["int8", "bf16", "fp32"])
-@pytest.mark.parametrize("weight_type", ["int8", "int4_clip", "int3_clip", "int2_clip", "nf4", "fp4_e2m1_bnb", "fp4_e2m1", "fp8_e5m2", "fp8_e4m3"])
+@pytest.mark.parametrize("weight_type", ["int8", "int4", "int3", "int2", "nf4", "fp4_e2m1_bnb", "fp4_e2m1", "fp8_e5m2", "fp8_e4m3"])
 @pytest.mark.parametrize("scale_type", ["fp32", "bf16", "fp8_e8m0"])
 @pytest.mark.parametrize("asym", [True, False])
 @pytest.mark.parametrize("transpose", [True, False])

0 commit comments

Comments
 (0)