This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 6a15b48

Add DynamicQuantConfig and QuantAwareTrainingConfig (#1505)
1 parent 8116fbb

File tree (5 files changed, +288 −4 lines):

  intel_extension_for_transformers/transformers/__init__.py
  intel_extension_for_transformers/transformers/modeling/modeling_auto.py
  intel_extension_for_transformers/transformers/utils/__init__.py
  intel_extension_for_transformers/transformers/utils/config.py
  tests/CI/test_quantization.py
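
For orientation, a minimal usage sketch of the two configs this commit adds, mirroring the new test cases further down. The checkpoint name is a hypothetical placeholder, not something taken from this commit.

from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    DynamicQuantConfig,
    QuantAwareTrainingConfig,
)

model_name_or_path = "facebook/opt-125m"  # hypothetical small checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Post-training dynamic quantization: no calibration data or tokenizer needed.
dq_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, quantization_config=DynamicQuantConfig()
)

# Quantization-aware training with the built-in default training loop.
qat_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=QuantAwareTrainingConfig(tokenizer=tokenizer, train_iters=2),
)
qat_model.save_pretrained("./saved_results")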

intel_extension_for_transformers/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@
     BitsAndBytesConfig,
     SmoothQuantConfig,
     StaticQuantConfig,
+    DynamicQuantConfig,
+    QuantAwareTrainingConfig,
     RtnConfig,
     AwqConfig,
     TeqConfig,

intel_extension_for_transformers/transformers/modeling/modeling_auto.py

Lines changed: 196 additions & 3 deletions
@@ -43,6 +43,8 @@
     MixedPrecisionConfig,
     SmoothQuantConfig,
     StaticQuantConfig,
+    DynamicQuantConfig,
+    QuantAwareTrainingConfig,
     RtnConfig,
     AwqConfig,
     TeqConfig,
@@ -412,7 +414,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                     "Quantization_config loading failed. If you want to load saved "
                     "low bit model, please check your quantizate_config.json."
                 )
-        elif use_neural_speed and not config.quantization_config["quant_method"] == "static":
+        elif use_neural_speed and not config.quantization_config["quant_method"] in ["dynamic", "static", "qat"]:
             if not os.path.exists(pretrained_model_name_or_path):
                 from huggingface_hub import snapshot_download
                 pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path,
@@ -972,6 +974,42 @@ def calib_func(model):
                 ),
             )
             logger.info("SmoothQuant done.")
+        elif isinstance(quantization_config, DynamicQuantConfig):
+            model = cls.ORIG_MODEL.from_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                config=config,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float,
+                **kwargs,
+            )
+
+            if (
+                not torch.cuda.is_available()
+                or device_map == "cpu"
+                or device_map == torch.device("cpu")
+            ) and model.config.model_type == "chatglm":
+                model = model.float()
+            model.eval()
+            logger.info("Applying DynamicQuant.")
+            # call inc dynamic quant
+            from neural_compressor import PostTrainingQuantConfig, quantization
+
+            conf = PostTrainingQuantConfig(
+                approach="dynamic",
+                excluded_precisions=quantization_config.excluded_precisions,
+                op_type_dict=quantization_config.op_type_dict,
+                op_name_dict=quantization_config.op_name_dict,
+            )
+            model = quantization.fit(
+                model,
+                conf,
+            )
+            model.save_pretrained = types.MethodType(save_low_bit, model)
+            quantization_config.remove_redundant_parameters()
+            model.quantization_config = quantization_config
+            logger.info("DynamicQuant done.")
+            return model
         elif isinstance(quantization_config, StaticQuantConfig):
             if quantization_config.backend == "ipex":
                 try:
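
The DynamicQuant branch above is a thin wrapper over Intel Neural Compressor's post-training dynamic quantization. A standalone sketch of the underlying call on a hypothetical toy module (not code from this commit), assuming neural_compressor 2.x:

import torch
from neural_compressor import PostTrainingQuantConfig, quantization

# Hypothetical toy model; dynamic quantization needs no calibration dataloader.
toy = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
toy.eval()

conf = PostTrainingQuantConfig(approach="dynamic")  # op_type_dict / op_name_dict are optional
q_toy = quantization.fit(toy, conf)                 # weights quantized ahead of time, activations at runtime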
@@ -1107,7 +1145,7 @@ def calib_func(model):
             from neural_compressor import PostTrainingQuantConfig, quantization

             conf = PostTrainingQuantConfig(
-                backend=quantization_config.backend, # default is ipex
+                backend=quantization_config.backend,
                 excluded_precisions=quantization_config.excluded_precisions,
                 op_type_dict=quantization_config.op_type_dict,
                 op_name_dict=quantization_config.op_name_dict,
@@ -1123,6 +1161,157 @@ def calib_func(model):
             model.quantization_config = quantization_config
             logger.info("StaticQuant done.")
             return model
+        elif isinstance(quantization_config, QuantAwareTrainingConfig):
+            model = cls.ORIG_MODEL.from_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                config=config,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float,
+                **kwargs,
+            )
+
+            if (
+                not torch.cuda.is_available()
+                or device_map == "cpu"
+                or device_map == torch.device("cpu")
+            ) and model.config.model_type == "chatglm":
+                model = model.float()
+            logger.info("Applying QuantAwareTraining.")
+            # train function
+            train_func = quantization_config.train_func
+            tokenizer = quantization_config.tokenizer
+            if train_func is None:
+                if quantization_config.tokenizer is None:
+                    logger.error(
+                        "Please provide the tokenizer or provide train_func directly,"
+                        + " the following is how to get tokenizer. \n"
+                        + " from transformers import AutoTokenizer \n"
+                        + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n"
+                    )
+                    exit(0)
+
+                from datasets import load_dataset
+                from torch.utils.data import DataLoader
+
+                train_dataset = quantization_config.train_dataset
+                train_shuffle = quantization_config.train_shuffle
+                train_iters = quantization_config.train_iters
+                train_padding = quantization_config.train_padding
+                train_len = quantization_config.train_len
+                train_pad_val = quantization_config.train_pad_val
+                from torch.nn.functional import pad
+
+                train_dataset = load_dataset(
+                    train_dataset,
+                    split=(
+                        "test"
+                        if train_dataset in ["mbpp", "openai_humaneval"]
+                        else "train"
+                    ),
+                )
+                if train_shuffle:
+                    train_dataset = train_dataset.shuffle(seed=42)
+
+                def tokenize_function(examples):
+                    if "code" in examples:
+                        example = tokenizer(examples["code"])
+                    elif "prompt" in examples:
+                        example = tokenizer(examples["prompt"])
+                    elif "text" in examples:
+                        example = tokenizer(examples["text"])
+                    else:
+                        logger.error(
+                            "Please check dataset prompt identifier,"
+                            + " NeelNanda/pile-10k is default used calibration dataset."
+                        )
+                        exit(0)
+                    return example
+
+                def collate_batch(batch):
+                    input_ids_padded = []
+                    last_ind = []
+                    for text in batch:
+                        input_ids = text["input_ids"]
+                        if not train_padding:
+                            input_ids = (
+                                input_ids[: int(train_len)]
+                                if len(input_ids) > int(train_len)
+                                else input_ids
+                            )  # no_padding
+                        else:
+                            pad_len = train_len - input_ids.shape[0]
+                            input_ids = pad(
+                                input_ids, (0, pad_len), value=train_pad_val
+                            )
+
+                        last_ind.append(input_ids.shape[0] - 1)
+                        input_ids_padded.append(input_ids)
+
+                    return (
+                        {
+                            "input_ids": torch.vstack(input_ids_padded),
+                        },
+                        torch.tensor(last_ind),
+                    )
+
+
+                tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
+                tokenized_dataset.set_format(type="torch", columns=["input_ids"])
+                train_dataloader = DataLoader(
+                    tokenized_dataset,
+                    batch_size=quantization_config.train_batch_size,
+                    shuffle=False,
+                    collate_fn=collate_batch,
+                )
+
+                def train_func(model):
+                    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+                    # switch to train mode
+                    model.train()
+                    for i, (inputs, last_ind) in enumerate(train_dataloader):
+                        if i >= train_iters:
+                            break
+                        output = model(**inputs)
+                        if isinstance(output, tuple):
+                            loss = output[0].mean()
+                        elif isinstance(output, dict):
+                            loss = output["logits"].mean()
+                        else:
+                            loss = output.mean()
+                        optimizer.zero_grad()
+                        loss.backward()
+                        optimizer.step()
+                        print('Iteration [{}], Loss: {:.4f}'.format(i + 1, loss))
+                    return model
+
+                logger.info(
+                    "The default training function is used, "
+                    + "the training dataset is NeelNanda/pile-10k, "
+                    + "the default batch size is 8 and the default iteration count is 100."
+                )
+                train_func = train_func
+
+
+            # call inc quant aware training
+            from neural_compressor import QuantizationAwareTrainingConfig, quantization
+            from neural_compressor.training import prepare_compression
+            conf = QuantizationAwareTrainingConfig(
+                backend=quantization_config.backend,
+                excluded_precisions=quantization_config.excluded_precisions,
+                op_type_dict=quantization_config.op_type_dict,
+                op_name_dict=quantization_config.op_name_dict,
+            )
+            compression_manager = prepare_compression(model, conf)
+            compression_manager.callbacks.on_train_begin()
+            model = compression_manager.model
+            train_func(model)
+            compression_manager.callbacks.on_train_end()
+            compression_manager.model.save_pretrained = types.MethodType(save_low_bit, model)
+            quantization_config.remove_redundant_parameters()
+            compression_manager.model.quantization_config = quantization_config
+            logger.info("Quant Aware Training done.")
+            return compression_manager.model
         else:
             if use_neural_speed:
                 logger.info("Using Neural Speed with FP32 model dtype.")
@@ -1255,6 +1444,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 quantization_config = AutoRoundConfig.from_dict(quantization_config)
             elif quantization_config["quant_method"] == "static":
                 quantization_config = StaticQuantConfig.from_dict(quantization_config)
+            elif quantization_config["quant_method"] == "dynamic":
+                quantization_config = DynamicQuantConfig.from_dict(quantization_config)
+            elif quantization_config["quant_method"] == "qat":
+                quantization_config = QuantAwareTrainingConfig.from_dict(quantization_config)
         assert (
             quantization_config is not None
         ), "Detect this model is not a low-bit model."
@@ -1499,7 +1692,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         # - we assume all floating dtype weights are of the same dtype
         # we also may have config.torch_dtype available, but we won't rely on it till v5
         # Pretrained Model
-        if quantization_config.quant_method == "static":
+        if quantization_config.quant_method in ["static", "dynamic", "qat"]:
             model = model_class(config, *model_args, **kwargs)
             from neural_compressor.utils.pytorch import load
             weights_file = os.path.join(

intel_extension_for_transformers/transformers/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@
     BitsAndBytesConfig,
     SmoothQuantConfig,
     StaticQuantConfig,
+    DynamicQuantConfig,
+    QuantAwareTrainingConfig,
     SparsityConfig,
     RtnConfig,
     AwqConfig,

intel_extension_for_transformers/transformers/utils/config.py

Lines changed: 55 additions & 1 deletion
@@ -48,8 +48,11 @@ class QuantizationMethod(str, Enum):
     RTN = "rtn"
     AUTOROUND = "autoround"
     TEQ = "teq"
+    DYNAMIC = "dynamic"
     STATIC = "static"
     SmoothQuant = "sq"
+    QuantAwareTraining = "qat"
+


 class SparsityConfig(PretrainedConfig):
@@ -537,7 +540,9 @@ def remove_redundant_parameters(self):
             "double_quant_scale_dtype", "use_double_quant", "mse_range", "scheme", "tokenizer", "use_ggml",
             "use_neural_speed", "use_quant", "layer_wise", "blocksize", "nsamples", "max_input_length", "static_groups",
             "lr", "minmax_lr", "iters", "use_quant_input", "device", "calib_dataset", "calib_pad_val", "calib_shuffle",
-            "calib_padding", "example_inputs", "excluded_precisions", "op_name_dict", "op_type_dict"]
+            "calib_padding", "example_inputs", "excluded_precisions", "op_name_dict", "op_type_dict", "train_dataloader",
+            "train_func", "train_iters", "train_len", "train_padding", "train_dataset", "train_pad_val", "train_shuffle",
+            "train_batch_size"]
         for parameter in remove_parameters:
             if hasattr(self, parameter):
                 delattr(self, parameter)
@@ -600,6 +605,55 @@ def get_config_dict(
             pretrained_model_name_or_path, _configuration_file=cf, **kwargs
         )

+class QuantAwareTrainingConfig(ITREXQuantizationConfigMixin):
+    def __init__(
+        self,
+        backend="default",
+        tokenizer=None,
+        train_dataset="NeelNanda/pile-10k",
+        train_dataloader=None,
+        train_func=None,
+        train_shuffle=True,
+        train_iters=100,
+        train_padding=True,
+        train_batch_size=8,
+        train_len=512,
+        train_pad_val=1,
+        op_name_dict=None,
+        op_type_dict=None,
+        excluded_precisions=[],
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.QuantAwareTraining
+        self.backend = backend
+        self.tokenizer = tokenizer
+        self.train_dataset = train_dataset
+        self.train_dataloader = train_dataloader
+        self.train_func = train_func
+        self.train_shuffle = train_shuffle
+        self.train_iters = train_iters
+        self.train_padding = train_padding
+        self.train_len = train_len
+        self.train_pad_val = train_pad_val
+        self.train_batch_size = train_batch_size
+        self.op_name_dict = op_name_dict
+        self.op_type_dict = op_type_dict
+        self.excluded_precisions = excluded_precisions
+
+
+class DynamicQuantConfig(ITREXQuantizationConfigMixin):
+    def __init__(
+        self,
+        excluded_precisions=[],
+        op_name_dict=None,
+        op_type_dict=None,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.DYNAMIC
+        self.excluded_precisions = excluded_precisions
+        self.op_name_dict = op_name_dict
+        self.op_type_dict = op_type_dict
+
 class StaticQuantConfig(ITREXQuantizationConfigMixin):
     def __init__(
         self,
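
Both new config classes accept the same op-level overrides that neural_compressor understands (op_type_dict, op_name_dict, excluded_precisions). A construction sketch with illustrative, non-default values; the dict layout follows the usual neural_compressor convention and is not taken from this commit:

from intel_extension_for_transformers.transformers import DynamicQuantConfig, QuantAwareTrainingConfig

dq_config = DynamicQuantConfig(
    excluded_precisions=["bf16"],                      # illustrative: skip bf16 mixed precision
    op_type_dict={"Embedding": {"weight": {"dtype": ["fp32"]},
                                "activation": {"dtype": ["fp32"]}}},
)

qat_config = QuantAwareTrainingConfig(
    train_dataset="NeelNanda/pile-10k",  # the default, shown explicitly
    train_iters=50,
    train_batch_size=4,
    train_len=256,
)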

tests/CI/test_quantization.py

Lines changed: 33 additions & 0 deletions
@@ -317,6 +317,8 @@ def test_quantization_for_llm(self):
             MixedPrecisionConfig,
             SmoothQuantConfig,
             StaticQuantConfig,
+            DynamicQuantConfig,
+            QuantAwareTrainingConfig,
             RtnConfig,
             AwqConfig,
             TeqConfig,
@@ -327,6 +329,21 @@ def test_quantization_for_llm(self):
         from intel_extension_for_transformers.transformers import AutoModelForCausalLM
         fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, use_neural_speed=False)
         dummy_input = fp32_model.dummy_inputs["input_ids"]
+
+        # Dynamic quant
+        dq_config = DynamicQuantConfig()
+        q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                       quantization_config=dq_config,
+                                                       )
+        q_model.eval()
+        output = q_model(dummy_input)
+        q_model.save_pretrained("./saved_results")
+        output = q_model(dummy_input)
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.17140813171863556, rel_tol=1e-04))
+        q_model = AutoModelForCausalLM.from_pretrained("./saved_results"
+                                                       )
+        output = q_model(dummy_input)
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.17140813171863556, rel_tol=1e-04))
         # Static quant
         sq_config = StaticQuantConfig(
             tokenizer=tokenizer,  # either two of one, tokenizer or calib_func
@@ -343,6 +360,22 @@ def test_quantization_for_llm(self):
         loading_model.eval()
         output = loading_model(dummy_input)
         self.assertTrue(isclose(float(output[0][0][0][0]), 0.17378684878349304, rel_tol=1e-04))
+        # Quant aware training
+        qat_config = QuantAwareTrainingConfig(
+            tokenizer=tokenizer,  # either two of one, tokenizer or train_func
+            train_iters=2,
+        )
+        q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                       quantization_config=qat_config,
+                                                       )
+        q_model.eval()
+        output = q_model(dummy_input)
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.17362995445728302, rel_tol=1e-04))
+        q_model.save_pretrained("./saved_results")
+        loading_model = AutoModelForCausalLM.from_pretrained("./saved_results")
+        loading_model.eval()
+        output = loading_model(dummy_input)
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.17362995445728302, rel_tol=1e-04))
         # Smoothquant
         sq_config = SmoothQuantConfig(
             tokenizer=tokenizer,  # either two of one, tokenizer or calib_func
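
As an alternative to passing a tokenizer (which triggers the default pile-10k loop added in modeling_auto.py above), QuantAwareTrainingConfig also accepts a user-supplied train_func. A hedged sketch of that path; the checkpoint name and the toy loop are placeholders, not part of this commit:

import torch
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, QuantAwareTrainingConfig

def my_train_func(model):
    # Toy loop: random token ids and a couple of optimizer steps, just to exercise the QAT graph.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    model.train()
    for _ in range(2):
        input_ids = torch.randint(0, 100, (1, 32))
        loss = model(input_ids=input_ids)["logits"].mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model

qat_config = QuantAwareTrainingConfig(train_func=my_train_func)
q_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",          # hypothetical checkpoint
                                               quantization_config=qat_config)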
