44 | 44 | SmoothQuantConfig,
45 | 45 | StaticQuantConfig,
46 | 46 | DynamicQuantConfig,
| 47 | + QuantAwareTrainingConfig, |
47 | 48 | RtnConfig,
48 | 49 | AwqConfig,
49 | 50 | TeqConfig,
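For orientation, here is a minimal sketch of how the newly imported `QuantAwareTrainingConfig` might be constructed. The field names (`tokenizer`, `train_dataset`, `train_iters`) are taken from the hunks below; the import path and the concrete values are assumptions, not confirmed by this diff.

```python
# Hypothetical construction sketch; the field names appear later in this
# diff, while the import path and values are assumptions.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import QuantAwareTrainingConfig

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # example model
quantization_config = QuantAwareTrainingConfig(
    tokenizer=tokenizer,                 # needed when no train_func is given
    train_dataset="NeelNanda/pile-10k",  # the default named in the error text
    train_iters=100,                     # number of QAT steps
)
```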
@@ -413,7 +414,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
413 | 414 | "Quantization_config loading failed. If you want to load saved "
414 | 415 | "low bit model, please check your quantizate_config.json."
415 | 416 | )
416 | | - elif use_neural_speed and not config.quantization_config["quant_method"] in ["dynamic", "static"]: |
| 417 | + elif use_neural_speed and config.quantization_config["quant_method"] not in ["dynamic", "static", "qat"]: |
417 | 418 | if not os.path.exists(pretrained_model_name_or_path):
418 | 419 | from huggingface_hub import snapshot_download
419 | 420 | pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path,
@@ -1160,6 +1161,157 @@ def calib_func(model):
1160 | 1161 | model.quantization_config = quantization_config
1161 | 1162 | logger.info("StaticQuant done.")
1162 | 1163 | return model
| 1164 | + elif isinstance(quantization_config, QuantAwareTrainingConfig): |
| 1165 | + model = cls.ORIG_MODEL.from_pretrained( |
| 1166 | + pretrained_model_name_or_path, |
| 1167 | + *model_args, |
| 1168 | + config=config, |
| 1169 | + low_cpu_mem_usage=True, |
| 1170 | + torch_dtype=torch.float, |
| 1171 | + **kwargs, |
| 1172 | + ) |
| 1173 | + |
| 1174 | + if ( |
| 1175 | + not torch.cuda.is_available() |
| 1176 | + or device_map == "cpu" |
| 1177 | + or device_map == torch.device("cpu") |
| 1178 | + ) and model.config.model_type == "chatglm": |
| 1179 | + model = model.float() |
| 1180 | + logger.info("Applying QuantAwareTraining.") |
| 1181 | + # train function |
| 1182 | + train_func = quantization_config.train_func |
| 1183 | + tokenizer = quantization_config.tokenizer |
| 1184 | + if train_func is None: |
| 1185 | + if quantization_config.tokenizer is None: |
| 1186 | + logger.error( |
| 1187 | + "Please provide a tokenizer, or provide train_func directly." |
| 1188 | + + " A tokenizer can be obtained as follows: \n" |
| 1189 | + + " from transformers import AutoTokenizer \n" |
| 1190 | + + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n" |
| 1191 | + ) |
| 1192 | + exit(1) |
| 1193 | + |
| 1194 | + from datasets import load_dataset |
| 1195 | + from torch.utils.data import DataLoader |
| 1196 | + |
| 1197 | + train_dataset = quantization_config.train_dataset |
| 1198 | + train_shuffle = quantization_config.train_shuffle |
| 1199 | + train_iters = quantization_config.train_iters |
| 1200 | + train_padding = quantization_config.train_padding |
| 1201 | + train_len = quantization_config.train_len |
| 1202 | + train_pad_val = quantization_config.train_pad_val |
| 1203 | + from torch.nn.functional import pad |
| 1204 | + |
| 1205 | + train_dataset = load_dataset( |
| 1206 | + train_dataset, |
| 1207 | + split=( |
| 1208 | + "test" |
| 1209 | + if train_dataset in ["mbpp", "openai_humaneval"] |
| 1210 | + else "train" |
| 1211 | + ), |
| 1212 | + ) |
| 1213 | + if train_shuffle: |
| 1214 | + train_dataset = train_dataset.shuffle(seed=42) |
| 1215 | + |
| 1216 | + def tokenize_function(examples): |
| 1217 | + if "code" in examples: |
| 1218 | + example = tokenizer(examples["code"]) |
| 1219 | + elif "prompt" in examples: |
| 1220 | + example = tokenizer(examples["prompt"]) |
| 1221 | + elif "text" in examples: |
| 1222 | + example = tokenizer(examples["text"]) |
| 1223 | + else: |
| 1224 | + logger.error( |
| 1225 | + "Please check the dataset's text field; a 'code', 'prompt'," |
| 1226 | + + " or 'text' column is expected. NeelNanda/pile-10k is the default dataset." |
| 1227 | + ) |
| 1228 | + exit(1) |
| 1229 | + return example |
| 1230 | + |
| 1231 | + def collate_batch(batch): |
| 1232 | + input_ids_padded = [] |
| 1233 | + last_ind = [] |
| 1234 | + for text in batch: |
| 1235 | + input_ids = text["input_ids"] |
| 1236 | + if not train_padding: |
| 1237 | + input_ids = ( |
| 1238 | + input_ids[: int(train_len)] |
| 1239 | + if len(input_ids) > int(train_len) |
| 1240 | + else input_ids |
| 1241 | + ) # no_padding |
| 1242 | + else: |
| 1243 | + pad_len = train_len - input_ids.shape[0] |
| 1244 | + input_ids = pad( |
| 1245 | + input_ids, (0, pad_len), value=train_pad_val |
| 1246 | + ) |
| 1247 | + |
| 1248 | + last_ind.append(input_ids.shape[0] - 1) |
| 1249 | + input_ids_padded.append(input_ids) |
| 1250 | + |
| 1251 | + return ( |
| 1252 | + { |
| 1253 | + "input_ids": torch.vstack(input_ids_padded), |
| 1254 | + }, |
| 1255 | + torch.tensor(last_ind), |
| 1256 | + ) |
| 1257 | + |
| 1259 | + tokenized_dataset = train_dataset.map(tokenize_function, batched=True) |
| 1260 | + tokenized_dataset.set_format(type="torch", columns=["input_ids"]) |
| 1261 | + train_dataloader = DataLoader( |
| 1262 | + tokenized_dataset, |
| 1263 | + batch_size=quantization_config.train_batch_size, |
| 1264 | + shuffle=False, |
| 1265 | + collate_fn=collate_batch, |
| 1266 | + ) |
| 1267 | + |
| 1268 | + def train_func(model): |
| 1269 | + optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) |
| 1270 | + # switch to training mode |
| 1271 | + model.train() |
| 1272 | + for i, (inputs, last_ind) in enumerate(train_dataloader): |
| 1273 | + if i >= train_iters: |
| 1274 | + break |
| 1275 | + output = model(**inputs) |
| 1276 | + if isinstance(output, tuple): |
| 1277 | + loss = output[0].mean() |
| 1278 | + elif isinstance(output, dict): |
| 1279 | + loss = output["logits"].mean() |
| 1280 | + else: |
| 1281 | + loss = output.mean() |
| 1282 | + optimizer.zero_grad() |
| 1283 | + loss.backward() |
| 1284 | + optimizer.step() |
| 1285 | + logger.info("Iteration [{}], Loss: {:.4f}".format(i + 1, loss)) |
| 1286 | + return model |
| 1287 | + |
| 1288 | + logger.info( |
| 1289 | + "The default training function is used, " |
| 1290 | + + "the training dataset defaults to NeelNanda/pile-10k, " |
| 1291 | + + "with batch size and iterations taken from train_batch_size and train_iters." |
| 1292 | + ) |
| 1293 | + |
| 1296 | + # call INC quantization-aware training |
| 1297 | + from neural_compressor import QuantizationAwareTrainingConfig |
| 1298 | + from neural_compressor.training import prepare_compression |
| 1299 | + conf = QuantizationAwareTrainingConfig( |
| 1300 | + backend=quantization_config.backend, |
| 1301 | + excluded_precisions=quantization_config.excluded_precisions, |
| 1302 | + op_type_dict=quantization_config.op_type_dict, |
| 1303 | + op_name_dict=quantization_config.op_name_dict, |
| 1304 | + ) |
| 1305 | + compression_manager = prepare_compression(model, conf) |
| 1306 | + compression_manager.callbacks.on_train_begin() |
| 1307 | + model = compression_manager.model |
| 1308 | + train_func(model) |
| 1309 | + compression_manager.callbacks.on_train_end() |
| 1310 | + compression_manager.model.save_pretrained = types.MethodType(save_low_bit, model) |
| 1311 | + quantization_config.remove_redundant_parameters() |
| 1312 | + compression_manager.model.quantization_config = quantization_config |
| 1313 | + logger.info("QuantAwareTraining done.") |
| 1314 | + return compression_manager.model |
1163 | 1315 | else:
1164 | 1316 | if use_neural_speed:
1165 | 1317 | logger.info("Using Neural Speed with FP32 model dtype.")
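To see the whole new branch in context, here is a hedged end-to-end sketch of how a caller might exercise it. `AutoModelForCausalLM`, the package path, and the model name are assumptions based on the surrounding file, not part of this diff.

```python
# Sketch only: drives the QuantAwareTrainingConfig branch added above.
# Import paths and the model name are assumptions.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    QuantAwareTrainingConfig,
)

model_name = "facebook/opt-125m"  # hypothetical example
tokenizer = AutoTokenizer.from_pretrained(model_name)
quantization_config = QuantAwareTrainingConfig(tokenizer=tokenizer)

# With no train_func supplied, from_pretrained builds the default
# pile-10k dataloader and training loop shown in the diff, then returns
# compression_manager.model with save_pretrained bound to save_low_bit.
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=quantization_config
)
model.save_pretrained("./opt-125m-qat")
```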
@@ -1294,6 +1446,8 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1294 | 1446 | quantization_config = StaticQuantConfig.from_dict(quantization_config)
1295 | 1447 | elif quantization_config["quant_method"] == "dynamic":
1296 | 1448 | quantization_config = DynamicQuantConfig.from_dict(quantization_config)
| 1449 | + elif quantization_config["quant_method"] == "qat": |
| 1450 | + quantization_config = QuantAwareTrainingConfig.from_dict(quantization_config) |
1297 | 1451 | assert (
1298 | 1452 | quantization_config is not None
1299 | 1453 | ), "Detect this model is not a low-bit model."
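Because `load_low_bit` now maps `quant_method == "qat"` back to a `QuantAwareTrainingConfig`, a QAT model saved via `save_low_bit` should reload through the same path as static and dynamic models. A minimal round-trip sketch, with the directory and entry point as assumptions:

```python
# Hedged round-trip sketch; the directory is the one saved in the
# previous example and the entry point is an assumption.
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

loaded = AutoModelForCausalLM.load_low_bit("./opt-125m-qat")
# "qat" now follows the same neural_compressor.utils.pytorch.load
# route as "static" and "dynamic" (see the hunk below).
```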
@@ -1538,7 +1692,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1538 | 1692 | # - we assume all floating dtype weights are of the same dtype
1539 | 1693 | # we also may have config.torch_dtype available, but we won't rely on it till v5
1540 | 1694 | # Pretrained Model
1541 | | - if quantization_config.quant_method in ["static", "dynamic"]: |
| 1695 | + if quantization_config.quant_method in ["static", "dynamic", "qat"]: |
1542 | 1696 | model = model_class(config, *model_args, **kwargs)
1543 | 1697 | from neural_compressor.utils.pytorch import load
1544 | 1698 | weights_file = os.path.join(
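Stripped of the loader plumbing, the INC calls this diff relies on follow the standard quantization-aware-training sequence (`prepare_compression`, `on_train_begin`, train, `on_train_end`). A condensed, self-contained sketch; the toy model and three-step loop are placeholders for the real model and `train_func`:

```python
# Condensed INC QAT flow mirroring the branch above; the tiny model and
# short loop are placeholders, not the real workload.
import torch
from neural_compressor import QuantizationAwareTrainingConfig
from neural_compressor.training import prepare_compression

fp32_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
conf = QuantizationAwareTrainingConfig()

compression_manager = prepare_compression(fp32_model, conf)
compression_manager.callbacks.on_train_begin()  # insert fake-quant ops
model = compression_manager.model

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for _ in range(3):  # stand-in for train_func(model)
    loss = model(torch.randn(4, 8)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

compression_manager.callbacks.on_train_end()  # convert to a quantized model
```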