This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit d70106d

add qatconfig
Signed-off-by: changwangss <[email protected]>
1 parent c7be2d9 commit d70106d

File tree

5 files changed (+217, -3 lines)


intel_extension_for_transformers/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
     SmoothQuantConfig,
     StaticQuantConfig,
     DynamicQuantConfig,
+    QuantAwareTrainingConfig,
     RtnConfig,
     AwqConfig,
     TeqConfig,
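With this one-line export, the new config class is importable from the package root alongside the other quantization configs. A minimal sketch, assuming the package is installed:

    from intel_extension_for_transformers.transformers import QuantAwareTrainingConfig

    qat_config = QuantAwareTrainingConfig(train_iters=2)
    print(qat_config.quant_method)  # QuantizationMethod.QuantAwareTraining ("qat")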

intel_extension_for_transformers/transformers/modeling/modeling_auto.py

Lines changed: 156 additions & 2 deletions
@@ -44,6 +44,7 @@
     SmoothQuantConfig,
     StaticQuantConfig,
     DynamicQuantConfig,
+    QuantAwareTrainingConfig,
     RtnConfig,
     AwqConfig,
     TeqConfig,
@@ -413,7 +414,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 "Quantization_config loading failed. If you want to load saved "
                 "low bit model, please check your quantizate_config.json."
             )
-        elif use_neural_speed and not config.quantization_config["quant_method"] in ["dynamic", "static"]:
+        elif use_neural_speed and not config.quantization_config["quant_method"] in ["dynamic", "static", "qat"]:
            if not os.path.exists(pretrained_model_name_or_path):
                from huggingface_hub import snapshot_download
                pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path,
@@ -1160,6 +1161,157 @@ def calib_func(model):
             model.quantization_config = quantization_config
             logger.info("StaticQuant done.")
             return model
+        elif isinstance(quantization_config, QuantAwareTrainingConfig):
+            model = cls.ORIG_MODEL.from_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                config=config,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float,
+                **kwargs,
+            )
+
+            if (
+                not torch.cuda.is_available()
+                or device_map == "cpu"
+                or device_map == torch.device("cpu")
+            ) and model.config.model_type == "chatglm":
+                model = model.float()
+            logger.info("Applying QuantAwareTraining.")
+            # training function
+            train_func = quantization_config.train_func
+            tokenizer = quantization_config.tokenizer
+            if train_func is None:
+                if quantization_config.tokenizer is None:
+                    logger.error(
+                        "Please provide the tokenizer or provide train_func directly,"
+                        + " the following is how to get the tokenizer. \n"
+                        + " from transformers import AutoTokenizer \n"
+                        + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n"
+                    )
+                    exit(0)
+
+                from datasets import load_dataset
+                from torch.nn.functional import pad
+                from torch.utils.data import DataLoader
+
+                train_dataset = quantization_config.train_dataset
+                train_shuffle = quantization_config.train_shuffle
+                train_iters = quantization_config.train_iters
+                train_padding = quantization_config.train_padding
+                train_len = quantization_config.train_len
+                train_pad_val = quantization_config.train_pad_val
+
+                train_dataset = load_dataset(
+                    train_dataset,
+                    split=(
+                        "test"
+                        if train_dataset in ["mbpp", "openai_humaneval"]
+                        else "train"
+                    ),
+                )
+                if train_shuffle:
+                    train_dataset = train_dataset.shuffle(seed=42)
+
+                def tokenize_function(examples):
+                    if "code" in examples:
+                        example = tokenizer(examples["code"])
+                    elif "prompt" in examples:
+                        example = tokenizer(examples["prompt"])
+                    elif "text" in examples:
+                        example = tokenizer(examples["text"])
+                    else:
+                        logger.error(
+                            "Please check the dataset prompt identifier,"
+                            + " NeelNanda/pile-10k is the default training dataset."
+                        )
+                        exit(0)
+                    return example
+
+                def collate_batch(batch):
+                    input_ids_padded = []
+                    last_ind = []
+                    for text in batch:
+                        input_ids = text["input_ids"]
+                        if not train_padding:
+                            input_ids = (
+                                input_ids[: int(train_len)]
+                                if len(input_ids) > int(train_len)
+                                else input_ids
+                            )  # no padding
+                        else:
+                            pad_len = train_len - input_ids.shape[0]
+                            input_ids = pad(
+                                input_ids, (0, pad_len), value=train_pad_val
+                            )
+
+                        last_ind.append(input_ids.shape[0] - 1)
+                        input_ids_padded.append(input_ids)
+
+                    return (
+                        {
+                            "input_ids": torch.vstack(input_ids_padded),
+                        },
+                        torch.tensor(last_ind),
+                    )
+
+                tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
+                tokenized_dataset.set_format(type="torch", columns=["input_ids"])
+                train_dataloader = DataLoader(
+                    tokenized_dataset,
+                    batch_size=quantization_config.train_batch_size,
+                    shuffle=False,
+                    collate_fn=collate_batch,
+                )
+
+                def train_func(model):
+                    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+                    # switch to training mode
+                    model.train()
+                    for i, (inputs, last_ind) in enumerate(train_dataloader):
+                        if i >= train_iters:
+                            break
+                        output = model(**inputs)
+                        if isinstance(output, tuple):
+                            loss = output[0].mean()
+                        elif isinstance(output, dict):
+                            loss = output["logits"].mean()
+                        else:
+                            loss = output.mean()
+                        optimizer.zero_grad()
+                        loss.backward()
+                        optimizer.step()
+                        print("Iteration [{}], Loss: {:.4f}".format(i + 1, loss))
+                    return model
+
+                logger.info(
+                    "The default training function is used, "
+                    + "the training dataset is NeelNanda/pile-10k, "
+                    + "batch size is 8 and the training iteration is 100."
+                )
+
+            # call INC quantization-aware training
+            from neural_compressor import QuantizationAwareTrainingConfig
+            from neural_compressor.training import prepare_compression
+            conf = QuantizationAwareTrainingConfig(
+                backend=quantization_config.backend,
+                excluded_precisions=quantization_config.excluded_precisions,
+                op_type_dict=quantization_config.op_type_dict,
+                op_name_dict=quantization_config.op_name_dict,
+            )
+            compression_manager = prepare_compression(model, conf)
+            compression_manager.callbacks.on_train_begin()
+            model = compression_manager.model
+            train_func(model)
+            compression_manager.callbacks.on_train_end()
+            compression_manager.model.save_pretrained = types.MethodType(save_low_bit, model)
+            quantization_config.remove_redundant_parameters()
+            compression_manager.model.quantization_config = quantization_config
+            logger.info("Quant Aware Training done.")
+            return compression_manager.model
         else:
             if use_neural_speed:
                 logger.info("Using Neural Speed with FP32 model dtype.")
@@ -1294,6 +1446,8 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
            quantization_config = StaticQuantConfig.from_dict(quantization_config)
        elif quantization_config["quant_method"] == "dynamic":
            quantization_config = DynamicQuantConfig.from_dict(quantization_config)
+       elif quantization_config["quant_method"] == "qat":
+           quantization_config = QuantAwareTrainingConfig.from_dict(quantization_config)
        assert (
            quantization_config is not None
        ), "Detect this model is not a low-bit model."
@@ -1538,7 +1692,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # - we assume all floating dtype weights are of the same dtype
        # we also may have config.torch_dtype available, but we won't rely on it till v5
        # Pretrained Model
-       if quantization_config.quant_method in ["static", "dynamic"]:
+       if quantization_config.quant_method in ["static", "dynamic", "qat"]:
            model = model_class(config, *model_args, **kwargs)
            from neural_compressor.utils.pytorch import load
            weights_file = os.path.join(
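The QAT branch above is a thin wrapper over Intel Neural Compressor's training API: prepare a compression manager, fire the train-begin callback to insert fake-quant ops, run a short fine-tuning loop, then fire the train-end callback to convert the model. A minimal standalone sketch of that flow, assuming `model` is a HuggingFace-style PyTorch causal LM and `train_dataloader` yields tokenized input dicts (both names are hypothetical here):

    import torch
    from neural_compressor import QuantizationAwareTrainingConfig
    from neural_compressor.training import prepare_compression

    # Prepare the model for quantization-aware training (inserts fake-quant ops).
    conf = QuantizationAwareTrainingConfig()
    compression_manager = prepare_compression(model, conf)
    compression_manager.callbacks.on_train_begin()
    model = compression_manager.model

    # A few fine-tuning steps; mirrors the commit's default loop, which
    # minimizes the mean logit rather than a real LM loss.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    model.train()
    for i, inputs in enumerate(train_dataloader):
        if i >= 2:  # a handful of iterations is enough for a smoke test
            break
        output = model(**inputs)  # assumes an output object with .logits
        loss = output.logits.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Convert the fake-quantized graph into a real quantized model.
    compression_manager.callbacks.on_train_end()

This is the same sequence the commit runs inside from_pretrained; the commit additionally patches save_pretrained onto the result so the quantized weights round-trip through save and load_low_bit.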

intel_extension_for_transformers/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
     SmoothQuantConfig,
     StaticQuantConfig,
     DynamicQuantConfig,
+    QuantAwareTrainingConfig,
     SparsityConfig,
     RtnConfig,
     AwqConfig,

intel_extension_for_transformers/transformers/utils/config.py

Lines changed: 42 additions & 1 deletion
@@ -51,6 +51,8 @@ class QuantizationMethod(str, Enum):
     DYNAMIC = "dynamic"
     STATIC = "static"
     SmoothQuant = "sq"
+    QuantAwareTraining = "qat"
+
 
 
 class SparsityConfig(PretrainedConfig):
@@ -538,7 +540,9 @@ def remove_redundant_parameters(self):
             "double_quant_scale_dtype", "use_double_quant", "mse_range", "scheme", "tokenizer", "use_ggml",
             "use_neural_speed", "use_quant", "layer_wise", "blocksize", "nsamples", "max_input_length", "static_groups",
             "lr", "minmax_lr", "iters", "use_quant_input", "device", "calib_dataset", "calib_pad_val", "calib_shuffle",
-            "calib_padding", "example_inputs", "excluded_precisions", "op_name_dict", "op_type_dict"]
+            "calib_padding", "example_inputs", "excluded_precisions", "op_name_dict", "op_type_dict", "train_dataloader",
+            "train_func", "train_iters", "train_len", "train_padding", "train_dataset", "train_pad_val", "train_shuffle",
+            "train_batch_size"]
         for parameter in remove_parameters:
             if hasattr(self, parameter):
                 delattr(self, parameter)
@@ -600,6 +604,43 @@ def get_config_dict(
         return super().get_config_dict(
             pretrained_model_name_or_path, _configuration_file=cf, **kwargs
         )
+
+
+class QuantAwareTrainingConfig(ITREXQuantizationConfigMixin):
+    def __init__(
+        self,
+        backend="default",
+        tokenizer=None,
+        train_dataset="NeelNanda/pile-10k",
+        train_dataloader=None,
+        train_func=None,
+        train_shuffle=True,
+        train_iters=100,
+        train_padding=True,
+        train_batch_size=8,
+        train_len=512,
+        train_pad_val=1,
+        op_name_dict=None,
+        op_type_dict=None,
+        excluded_precisions=[],
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.QuantAwareTraining
+        self.backend = backend
+        self.tokenizer = tokenizer
+        self.train_dataset = train_dataset
+        self.train_dataloader = train_dataloader
+        self.train_func = train_func
+        self.train_shuffle = train_shuffle
+        self.train_iters = train_iters
+        self.train_padding = train_padding
+        self.train_len = train_len
+        self.train_pad_val = train_pad_val
+        self.train_batch_size = train_batch_size
+        self.op_name_dict = op_name_dict
+        self.op_type_dict = op_type_dict
+        self.excluded_precisions = excluded_precisions
+
+
 class DynamicQuantConfig(ITREXQuantizationConfigMixin):
     def __init__(
         self,
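For illustration, a hedged sketch of driving the new config with a user-supplied training function, in which case the tokenizer and train_* dataset knobs above are unused (`my_train_func` is a hypothetical name):

    from intel_extension_for_transformers.transformers import QuantAwareTrainingConfig

    def my_train_func(model):
        # Hypothetical: run your own short fine-tuning loop on the
        # fake-quantized model, then return it.
        return model

    qat_config = QuantAwareTrainingConfig(
        train_func=my_train_func,      # skips the default pile-10k loop
        excluded_precisions=["bf16"],  # forwarded to neural_compressor
    )

Note that remove_redundant_parameters() strips all train_* fields before the config is serialized into the saved model's quantization config, which is why this commit adds them to the removal list above.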

tests/CI/test_quantization.py

Lines changed: 17 additions & 0 deletions
@@ -318,6 +318,7 @@ def test_quantization_for_llm(self):
             SmoothQuantConfig,
             StaticQuantConfig,
             DynamicQuantConfig,
+            QuantAwareTrainingConfig,
             RtnConfig,
             AwqConfig,
             TeqConfig,
@@ -359,6 +360,22 @@ def test_quantization_for_llm(self):
         loading_model.eval()
         output = loading_model(dummy_input)
         self.assertTrue(isclose(float(output[0][0][0][0]), 0.17378684878349304, rel_tol=1e-04))
+        # Quantization-aware training
+        qat_config = QuantAwareTrainingConfig(
+            tokenizer=tokenizer,  # one of tokenizer or train_func must be provided
+            train_iters=2,
+        )
+        q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                       quantization_config=qat_config,
+                                                       )
+        q_model.eval()
+        output = q_model(dummy_input)
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.17362995445728302, rel_tol=1e-04))
+        q_model.save_pretrained("./saved_results")
+        loading_model = AutoModelForCausalLM.from_pretrained("./saved_results")
+        loading_model.eval()
+        output = loading_model(dummy_input)
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.17362995445728302, rel_tol=1e-04))
         # Smoothquant
         sq_config = SmoothQuantConfig(
             tokenizer=tokenizer, # either two of one, tokenizer or calib_func
