@@ -136,11 +136,12 @@ def replace_linear(
     if modules_to_not_convert is None:
         # output_layer is chatglm last layer name
         # embed_out is dolly_v2 last layer name
-        modules_to_not_convert = ["lm_head", "output_layer", "embed_out"]
+        modules_to_not_convert = []
     if quantization_config.llm_int8_skip_modules:
-        modules_to_not_convert = modules_to_not_convert.extend(
+        modules_to_not_convert.extend(
             quantization_config.llm_int8_skip_modules
         )
+    modules_to_not_convert = list(set(modules_to_not_convert))
     model, is_replaced = _replace_linear(
         model,
         modules_to_not_convert,
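The hunk above changes the default skip list to empty, fixes the `extend()` misuse, and deduplicates the result before it reaches `_replace_linear`. A minimal, self-contained sketch of that behavior (the helper name and bare-list arguments are illustrative, not part of the patch):

```python
def build_modules_to_not_convert(modules_to_not_convert, llm_int8_skip_modules):
    # Default is now an empty list instead of the hard-coded
    # ["lm_head", "output_layer", "embed_out"] last-layer names.
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if llm_int8_skip_modules:
        # list.extend() mutates in place and returns None, so the old
        # "modules_to_not_convert = modules_to_not_convert.extend(...)"
        # silently replaced the list with None; call it without reassigning.
        modules_to_not_convert.extend(llm_int8_skip_modules)
    # Deduplicate before handing the list to _replace_linear.
    return list(set(modules_to_not_convert))

print(build_modules_to_not_convert(None, ["lm_head", "lm_head", "embed_out"]))
# e.g. ['embed_out', 'lm_head'] -- set() does not preserve order
```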
@@ -559,9 +560,11 @@ def convert_to_quantized_model(model, config, device="cpu"):
             group_size=config.group_size,
             use_layer_wise=config.layer_wise,
         )
-        quant_config.set_local(".*lm_head", RTNConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", RTNConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", RTNConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, RTNConfig(dtype="fp32"))
+        logger.info(f"Do RTN algorithm with config {quant_config}")
         model = prepare(model, quant_config)
         model = convert(model)
     elif config.quant_method.value == "awq":
@@ -575,9 +578,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             use_auto_clip=config.auto_clip,
             folding=True,
         )
-        quant_config.set_local(".*lm_head", AWQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", AWQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", AWQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, AWQConfig(dtype="fp32"))
         logger.info(f"Do AWQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -601,9 +605,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             use_layer_wise=config.layer_wise,
             absorb_to_layer=config.absorb_to_layer
         )
-        quant_config.set_local(".*lm_head", TEQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", TEQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", TEQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, TEQConfig(dtype="fp32"))
         logger.info(f"Do TEQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -632,9 +637,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             block_size=config.blocksize,
             static_groups=config.static_groups,
         )
-        quant_config.set_local(".*lm_head", GPTQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", GPTQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", GPTQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, GPTQConfig(dtype="fp32"))
         logger.info(f"Do GPTQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -662,10 +668,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             iters=config.iters,
             scale_dtype=config.scale_dtype,
         )
-        if config.quant_lm_head is False:
-            quant_config.set_local(".*lm_head", AutoRoundConfig(dtype="fp32"))
-            quant_config.set_local(".*output_layer", AutoRoundConfig(dtype="fp32"))
-            quant_config.set_local(".*embed_out", AutoRoundConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, AutoRoundConfig(dtype="fp32"))
         logger.info(f"Do AutoRound algorithm with config {quant_config}")
         dataloader = get_autoround_dataloader(tokenizer=config.tokenizer,
                                               seqlen=config.seq_len,
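All five algorithm branches (RTN, AWQ, TEQ, GPTQ, AutoRound) now share the same pattern: each name in `config.llm_int8_skip_modules` is turned into a `".*" + module` regex and pinned to fp32 via `quant_config.set_local(...)`. A rough stand-alone sketch of that pattern, using a fake config class in place of the real RTNConfig/AWQConfig/etc. objects (the class and its `dtype_for` lookup are illustrative assumptions, not the library's API):

```python
import re

class FakeAlgoConfig:
    """Illustrative stand-in for RTNConfig/AWQConfig/...; only mimics set_local()."""
    def __init__(self, dtype="int4"):
        self.dtype = dtype
        self.local = {}  # regex pattern -> per-module override config

    def set_local(self, pattern, override):
        self.local[pattern] = override

    def dtype_for(self, module_name):
        # First matching local override wins; otherwise use the global dtype.
        for pattern, override in self.local.items():
            if re.match(pattern, module_name):
                return override.dtype
        return self.dtype

skip_modules = ["lm_head", "output_layer", "embed_out"]  # config.llm_int8_skip_modules
quant_config = FakeAlgoConfig(dtype="int4")
if skip_modules != []:
    for module in skip_modules:
        module_name = ".*" + module
        quant_config.set_local(module_name, FakeAlgoConfig(dtype="fp32"))

print(quant_config.dtype_for("transformer.lm_head"))        # fp32 (kept in full precision)
print(quant_config.dtype_for("transformer.h.0.mlp.fc_in"))  # int4 (quantized)
```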