This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit b400cb9

Adapt quant lm head (#1671)
Signed-off-by: Wang, Chang <[email protected]>
1 parent 3e78ae8 commit b400cb9

File tree

4 files changed: +33, -31 lines


examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py

Lines changed: 0 additions & 7 deletions
@@ -96,12 +96,6 @@
     help="Use determined group to do quantization",
 )
 # ============AutoRound==================
-parser.add_argument(
-    "--autoround_iters",
-    default=2048,
-    type=int,
-    help="Calibration dataset max or padding max length for AutoRound.",
-)
 parser.add_argument(
     "--lr",
     type=float,
@@ -172,7 +166,6 @@
     bits=args.bits,
     sym=True if args.scheme == "sym" else False,
     group_size=args.group_size,
-    seq_len=args.seq_len,
     compute_dtype=args.compute_dtype,
     scale_dtype=args.compute_dtype,
     weight_dtype=args.weight_dtype,
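
With --autoround_iters and the explicit seq_len=args.seq_len gone, the example relies on the quantization config's own defaults for the AutoRound calibration length. A minimal sketch of the resulting call pattern, assuming the AutoRoundConfig and AutoModelForCausalLM exports this repo's example scripts use; the model name, tokenizer argument, device_map value, and weight_dtype value are illustrative assumptions, while the other keyword names mirror the context lines above.

# Hedged sketch: AutoRound weight-only quantization without a CLI-supplied seq_len.
# Assumes intel_extension_for_transformers exports AutoRoundConfig and
# AutoModelForCausalLM as in the example scripts; exact argument names may differ.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    AutoRoundConfig,
)

model_name = "facebook/opt-125m"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)

woq_config = AutoRoundConfig(
    tokenizer=tokenizer,          # assumed: config drives its own calibration
    bits=4,
    sym=False,
    group_size=128,               # seq_len / iters now come from config defaults
    compute_dtype="fp16",
    scale_dtype="fp16",
    weight_dtype="int4_fullrange",  # assumed value for illustration
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=woq_config,
    device_map="xpu",
    trust_remote_code=True,
)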

intel_extension_for_transformers/transformers/llm/quantization/utils.py

Lines changed: 24 additions & 18 deletions
@@ -136,11 +136,12 @@ def replace_linear(
     if modules_to_not_convert is None:
         # output_layer is chatglm last layer name
         # embed_out is dolly_v2 last layer name
-        modules_to_not_convert = ["lm_head", "output_layer", "embed_out"]
+        modules_to_not_convert = []
     if quantization_config.llm_int8_skip_modules:
-        modules_to_not_convert = modules_to_not_convert.extend(
+        modules_to_not_convert.extend(
             quantization_config.llm_int8_skip_modules
         )
+    modules_to_not_convert = list(set(modules_to_not_convert))
     model, is_replaced = _replace_linear(
         model,
         modules_to_not_convert,
@@ -559,9 +560,11 @@ def convert_to_quantized_model(model, config, device="cpu"):
             group_size=config.group_size,
             use_layer_wise=config.layer_wise,
         )
-        quant_config.set_local(".*lm_head", RTNConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", RTNConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", RTNConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, RTNConfig(dtype="fp32"))
+        logger.info(f"Do RTN algorithm with config {quant_config}")
         model = prepare(model, quant_config)
         model = convert(model)
     elif config.quant_method.value == "awq":
@@ -575,9 +578,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             use_auto_clip=config.auto_clip,
             folding=True,
         )
-        quant_config.set_local(".*lm_head", AWQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", AWQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", AWQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, AWQConfig(dtype="fp32"))
         logger.info(f"Do AWQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -601,9 +605,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             use_layer_wise=config.layer_wise,
             absorb_to_layer=config.absorb_to_layer
         )
-        quant_config.set_local(".*lm_head", TEQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", TEQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", TEQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, TEQConfig(dtype="fp32"))
         logger.info(f"Do TEQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -632,9 +637,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             block_size=config.blocksize,
             static_groups=config.static_groups,
         )
-        quant_config.set_local(".*lm_head", GPTQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", GPTQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", GPTQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, GPTQConfig(dtype="fp32"))
         logger.info(f"Do GPTQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -662,10 +668,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             iters=config.iters,
             scale_dtype=config.scale_dtype,
         )
-        if config.quant_lm_head is False:
-            quant_config.set_local(".*lm_head", AutoRoundConfig(dtype="fp32"))
-            quant_config.set_local(".*output_layer", AutoRoundConfig(dtype="fp32"))
-            quant_config.set_local(".*embed_out", AutoRoundConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, AutoRoundConfig(dtype="fp32"))
         logger.info(f"Do AutoRound algorithm with config {quant_config}")
         dataloader = get_autoround_dataloader(tokenizer=config.tokenizer,
                                               seqlen=config.seq_len,
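
Each algorithm branch above repeats the same replacement: instead of three hard-coded set_local calls for lm_head, output_layer, and embed_out, the fp32 fallback is built from config.llm_int8_skip_modules. A minimal sketch of that pattern, pulled into a hypothetical helper (set_fp32_fallbacks is illustrative, not a function in this repo); the import path follows neural-compressor's 3.x torch API and is an assumption here.

# Hedged sketch of the loop added in each branch above. The real code inlines
# this per algorithm (RTN/AWQ/TEQ/GPTQ/AutoRound); the helper is hypothetical.
from neural_compressor.torch.quantization import RTNConfig  # assumed import path


def set_fp32_fallbacks(quant_config, skip_modules, local_config_cls=RTNConfig):
    """Keep every module whose name matches a skip entry in fp32."""
    if skip_modules:  # equivalent to `config.llm_int8_skip_modules != []`
        for module in skip_modules:
            # set_local takes a regex over module names, hence the ".*" prefix.
            quant_config.set_local(".*" + module, local_config_cls(dtype="fp32"))
    return quant_config


# Usage mirroring the RTN branch with the new default skip list:
quant_config = RTNConfig(bits=4, group_size=128)
quant_config = set_fp32_fallbacks(
    quant_config, ["lm_head", "output_layer", "embed_out"]
)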

intel_extension_for_transformers/transformers/modeling/modeling_auto.py

Lines changed: 2 additions & 1 deletion
@@ -161,7 +161,7 @@ def build_woq_model(model, quantization_config):
     from neural_compressor.adaptor.torch_utils.util import set_module
     weight_dtype = quantization_config.weight_dtype
     for n, m in model.named_modules():
-        if "lm_head" in n or "output_layer" in n or "embed_out" in n:
+        if n in quantization_config.llm_int8_skip_modules:
             continue
         if isinstance(m, torch.nn.Linear):
             zp = getattr(
@@ -883,6 +883,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
                 hasattr(torch, "xpu") and torch.xpu.is_available()
             ), "There is no xpu device in this system!"
             quantization_config.update(**{"device": "xpu"})
+            quantization_config.post_init_xpu()
             if (
                 not torch.cuda.is_available()
                 or device_map == "cpu"
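
The first hunk changes the skip test from a substring match on the module name to exact membership in quantization_config.llm_int8_skip_modules. A small self-contained sketch of the difference; the toy module below is purely illustrative.

# Substring check (old) vs. exact-name membership (new) over named_modules().
import torch

skip_modules = ["lm_head", "output_layer", "embed_out"]  # new config default


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = torch.nn.Module()
        self.transformer.output_layer = torch.nn.Linear(8, 8)  # chatglm-style nesting
        self.lm_head = torch.nn.Linear(8, 8)                   # top-level head


for n, _ in Toy().named_modules():
    old_skip = any(key in n for key in ("lm_head", "output_layer", "embed_out"))
    new_skip = n in skip_modules
    print(f"{n or '<root>'}: old={old_skip} new={new_skip}")
# "lm_head" is skipped by both checks; "transformer.output_layer" only by the old
# substring check, so a nested name must now be listed in full to be skipped.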

intel_extension_for_transformers/transformers/utils/config.py

Lines changed: 7 additions & 5 deletions
@@ -831,7 +831,7 @@ def __init__(
         self.double_quant_bits = double_quant_bits
         self.double_quant_use_sym = double_quant_use_sym
         self.double_quant_group_size = double_quant_group_size
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -911,7 +911,7 @@ def __init__(
         self.true_sequential = true_sequential
         self.layer_wise = layer_wise
         self.seq_len = seq_len
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -1009,7 +1009,7 @@ def __init__(
         self.seq_len = seq_len
         self.use_double_quant = use_double_quant
         self.double_quant_scale_dtype = double_quant_scale_dtype
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -1078,7 +1078,7 @@ def __init__(
         self.seq_len = seq_len
         self.use_double_quant = use_double_quant
         self.double_quant_scale_dtype = double_quant_scale_dtype
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_neural_speed = use_neural_speed
         self.device = kwargs.get("device", "auto")
@@ -1154,7 +1154,9 @@ def __init__(
         self.iters = iters
         self.seq_len = seq_len
         self.quant_lm_head = quant_lm_head
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
+        if self.quant_lm_head:
+            self.llm_int8_skip_modules = []
         self.use_ggml = use_ggml
         self.use_neural_speed = use_neural_speed
         self.batch_size = kwargs.pop("batch_size", 8)
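
Two things change in the configs: llm_int8_skip_modules now defaults to the three common head names, and the config that owns quant_lm_head clears that list when quant_lm_head=True so the head is quantized like any other Linear. A minimal sketch of that interplay; WoqConfigSketch is a stand-in, not a class from this repo.

# Hedged sketch of the default/override interplay shown in the last hunk above.
class WoqConfigSketch:
    def __init__(self, quant_lm_head: bool = False, **kwargs):
        self.quant_lm_head = quant_lm_head
        # New default: skip the usual output-projection layers unless told otherwise.
        self.llm_int8_skip_modules = kwargs.get(
            "llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"]
        )
        # quant_lm_head wins: an empty skip list means every Linear, including
        # the lm_head, goes through weight-only quantization.
        if self.quant_lm_head:
            self.llm_int8_skip_modules = []


print(WoqConfigSketch().llm_int8_skip_modules)                    # ['lm_head', 'output_layer', 'embed_out']
print(WoqConfigSketch(quant_lm_head=True).llm_int8_skip_modules)  # []
print(WoqConfigSketch(llm_int8_skip_modules=["lm_head"]).llm_int8_skip_modules)  # ['lm_head']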
