     MixedPrecisionConfig,
     SmoothQuantConfig,
     StaticQuantConfig,
+    DynamicQuantConfig,
+    QuantAwareTrainingConfig,
     RtnConfig,
     AwqConfig,
     TeqConfig,
@@ -412,7 +414,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                     "Quantization_config loading failed. If you want to load saved "
                     "low bit model, please check your quantizate_config.json."
                 )
-        elif use_neural_speed and not config.quantization_config["quant_method"] == "static":
+        elif use_neural_speed and not config.quantization_config["quant_method"] in ["dynamic", "static", "qat"]:
             if not os.path.exists(pretrained_model_name_or_path):
                 from huggingface_hub import snapshot_download
                 pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path,
@@ -972,6 +974,42 @@ def calib_func(model):
                 ),
             )
             logger.info("SmoothQuant done.")
+        elif isinstance(quantization_config, DynamicQuantConfig):
+            model = cls.ORIG_MODEL.from_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                config=config,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float,
+                **kwargs,
+            )
+
+            if (
+                not torch.cuda.is_available()
+                or device_map == "cpu"
+                or device_map == torch.device("cpu")
+            ) and model.config.model_type == "chatglm":
+                model = model.float()
+            model.eval()
+            logger.info("Applying DynamicQuant.")
+            # call inc dynamic quant
+            from neural_compressor import PostTrainingQuantConfig, quantization
+
+            conf = PostTrainingQuantConfig(
+                approach="dynamic",
+                excluded_precisions=quantization_config.excluded_precisions,
+                op_type_dict=quantization_config.op_type_dict,
+                op_name_dict=quantization_config.op_name_dict,
+            )
+            model = quantization.fit(
+                model,
+                conf,
+            )
+            model.save_pretrained = types.MethodType(save_low_bit, model)
+            quantization_config.remove_redundant_parameters()
+            model.quantization_config = quantization_config
+            logger.info("DynamicQuant done.")
+            return model
         elif isinstance(quantization_config, StaticQuantConfig):
             if quantization_config.backend == "ipex":
                 try:
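Note: the new DynamicQuantConfig branch mirrors the existing StaticQuant path but drives INC post-training quantization with approach="dynamic", so no calibration dataloader is needed. Below is a minimal usage sketch, not part of the diff; it assumes DynamicQuantConfig and the AutoModelForCausalLM wrapper are exported from intel_extension_for_transformers.transformers alongside the other configs shown above, that from_pretrained accepts the config via quantization_config, and that the constructor takes the same fields the branch reads (excluded_precisions, op_type_dict, op_name_dict).

# Hypothetical usage sketch (not part of the diff).
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    DynamicQuantConfig,
)

# Fields forwarded to PostTrainingQuantConfig by the branch above; values are illustrative.
quant_config = DynamicQuantConfig(excluded_precisions=["bf16"])
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=quant_config,
)
model.save_pretrained("./opt-125m-dynamic-int8")  # save_pretrained is patched to save_low_bit above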
@@ -1107,7 +1145,7 @@ def calib_func(model):
             from neural_compressor import PostTrainingQuantConfig, quantization

             conf = PostTrainingQuantConfig(
-                backend=quantization_config.backend,  # default is ipex
+                backend=quantization_config.backend,
                 excluded_precisions=quantization_config.excluded_precisions,
                 op_type_dict=quantization_config.op_type_dict,
                 op_name_dict=quantization_config.op_name_dict,
@@ -1123,6 +1161,157 @@ def calib_func(model):
             model.quantization_config = quantization_config
             logger.info("StaticQuant done.")
             return model
+        elif isinstance(quantization_config, QuantAwareTrainingConfig):
+            model = cls.ORIG_MODEL.from_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                config=config,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float,
+                **kwargs,
+            )
+
+            if (
+                not torch.cuda.is_available()
+                or device_map == "cpu"
+                or device_map == torch.device("cpu")
+            ) and model.config.model_type == "chatglm":
+                model = model.float()
+            logger.info("Applying QuantAwareTraining.")
+            # train function
+            train_func = quantization_config.train_func
+            tokenizer = quantization_config.tokenizer
+            if train_func is None:
+                if quantization_config.tokenizer is None:
+                    logger.error(
+                        "Please provide the tokenizer or provide train_func directly,"
+                        + " the following is how to get tokenizer. \n"
+                        + " from transformers import AutoTokenizer \n"
+                        + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n"
+                    )
+                    exit(0)
+
+                from datasets import load_dataset
+                from torch.utils.data import DataLoader
+
+                train_dataset = quantization_config.train_dataset
+                train_shuffle = quantization_config.train_shuffle
+                train_iters = quantization_config.train_iters
+                train_padding = quantization_config.train_padding
+                train_len = quantization_config.train_len
+                train_pad_val = quantization_config.train_pad_val
+                from torch.nn.functional import pad
+
+                train_dataset = load_dataset(
+                    train_dataset,
+                    split=(
+                        "test"
+                        if train_dataset in ["mbpp", "openai_humaneval"]
+                        else "train"
+                    ),
+                )
+                if train_shuffle:
+                    train_dataset = train_dataset.shuffle(seed=42)
+
+                def tokenize_function(examples):
+                    if "code" in examples:
+                        example = tokenizer(examples["code"])
+                    elif "prompt" in examples:
+                        example = tokenizer(examples["prompt"])
+                    elif "text" in examples:
+                        example = tokenizer(examples["text"])
+                    else:
+                        logger.error(
+                            "Please check dataset prompt identifier,"
+                            + " NeelNanda/pile-10k is default used calibration dataset."
+                        )
+                        exit(0)
+                    return example
+
+                def collate_batch(batch):
+                    input_ids_padded = []
+                    last_ind = []
+                    for text in batch:
+                        input_ids = text["input_ids"]
+                        if not train_padding:
+                            input_ids = (
+                                input_ids[: int(train_len)]
+                                if len(input_ids) > int(train_len)
+                                else input_ids
+                            )  # no_padding
+                        else:
+                            pad_len = train_len - input_ids.shape[0]
+                            input_ids = pad(
+                                input_ids, (0, pad_len), value=train_pad_val
+                            )
+
+                        last_ind.append(input_ids.shape[0] - 1)
+                        input_ids_padded.append(input_ids)
+
+                    return (
+                        {
+                            "input_ids": torch.vstack(input_ids_padded),
+                        },
+                        torch.tensor(last_ind),
+                    )
+
+
+                tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
+                tokenized_dataset.set_format(type="torch", columns=["input_ids"])
+                train_dataloader = DataLoader(
+                    tokenized_dataset,
+                    batch_size=quantization_config.train_batch_size,
+                    shuffle=False,
+                    collate_fn=collate_batch,
+                )
+
+                def train_func(model):
+                    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+                    # switch to train mode
+                    model.train()
+                    for i, (inputs, last_ind) in enumerate(train_dataloader):
+                        if i >= train_iters:
+                            break
+                        output = model(**inputs)
+                        if isinstance(output, tuple):
+                            loss = output[0].mean()
+                        elif isinstance(output, dict):
+                            loss = output["logits"].mean()
+                        else:
+                            loss = output.mean()
+                        optimizer.zero_grad()
+                        loss.backward()
+                        optimizer.step()
+                        print('Iteration [{}], Loss: {:.4f}'.format(i + 1, loss))
+                    return model
+
+                logger.info(
+                    "The default train function is used, "
+                    + "the train dataset is NeelNanda/pile-10k, "
+                    + "batchsize is 1 and train iteration is 100."
+                )
+                train_func = train_func
+
+
+            # call inc quantization-aware training
+            from neural_compressor import QuantizationAwareTrainingConfig, quantization
+            from neural_compressor.training import prepare_compression
+            conf = QuantizationAwareTrainingConfig(
+                backend=quantization_config.backend,
+                excluded_precisions=quantization_config.excluded_precisions,
+                op_type_dict=quantization_config.op_type_dict,
+                op_name_dict=quantization_config.op_name_dict,
+            )
+            compression_manager = prepare_compression(model, conf)
+            compression_manager.callbacks.on_train_begin()
+            model = compression_manager.model
+            train_func(model)
+            compression_manager.callbacks.on_train_end()
+            compression_manager.model.save_pretrained = types.MethodType(save_low_bit, model)
+            quantization_config.remove_redundant_parameters()
+            compression_manager.model.quantization_config = quantization_config
+            logger.info("Quant Aware Training done.")
+            return compression_manager.model
         else:
             if use_neural_speed:
                 logger.info("Using Neural Speed with FP32 model dtype.")
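Note: the QuantAwareTraining branch wraps the FP32 model with prepare_compression, runs either the user-supplied train_func or the default NeelNanda/pile-10k loop built above, and then attaches save_low_bit for serialization. A hedged usage sketch follows, under the same namespace assumptions as the earlier sketch; the constructor arguments are assumed to match the attributes the branch reads (tokenizer, train_dataset, train_iters, train_batch_size), and the model name and values are illustrative only.

# Hypothetical usage sketch (not part of the diff).
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    QuantAwareTrainingConfig,
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
# With train_func left unset, the branch above builds its default training loop
# from the train_dataset, so a tokenizer must be supplied.
qat_config = QuantAwareTrainingConfig(
    tokenizer=tokenizer,
    train_dataset="NeelNanda/pile-10k",
    train_iters=100,
    train_batch_size=8,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=qat_config,
)
model.save_pretrained("./opt-125m-qat-int8")  # save_low_bit is patched onto the returned model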
@@ -1255,6 +1444,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             quantization_config = AutoRoundConfig.from_dict(quantization_config)
         elif quantization_config["quant_method"] == "static":
             quantization_config = StaticQuantConfig.from_dict(quantization_config)
+        elif quantization_config["quant_method"] == "dynamic":
+            quantization_config = DynamicQuantConfig.from_dict(quantization_config)
+        elif quantization_config["quant_method"] == "qat":
+            quantization_config = QuantAwareTrainingConfig.from_dict(quantization_config)
         assert (
             quantization_config is not None
         ), "Detect this model is not a low-bit model."
@@ -1499,7 +1692,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         # - we assume all floating dtype weights are of the same dtype
         # we also may have config.torch_dtype available, but we won't rely on it till v5
         # Pretrained Model
-        if quantization_config.quant_method == "static":
+        if quantization_config.quant_method in ["static", "dynamic", "qat"]:
             model = model_class(config, *model_args, **kwargs)
             from neural_compressor.utils.pytorch import load
             weights_file = os.path.join(