1
1
import torch
2
2
import numpy as np
3
3
import pandas as pd
4
- from tqdm .auto import tqdm
5
4
from transformers import (
6
- AdamW ,
7
5
T5ForConditionalGeneration ,
8
6
MT5ForConditionalGeneration ,
9
7
ByT5Tokenizer ,
12
10
MT5TokenizerFast as MT5Tokenizer ,
13
11
)
14
12
from transformers import AutoTokenizer
15
-
16
- # from fastT5 import export_and_get_onnx_model
13
+ from torch .optim import AdamW
17
14
from torch .utils .data import Dataset , DataLoader
18
15
from transformers import AutoModelWithLMHead , AutoTokenizer
19
16
import pytorch_lightning as pl
20
17
from pytorch_lightning .loggers import TensorBoardLogger
21
- from pytorch_lightning .callbacks import ModelCheckpoint
22
18
from pytorch_lightning .callbacks .early_stopping import EarlyStopping
19
+ from pytorch_lightning .callbacks .progress import TQDMProgressBar
23
20
24
21
torch .cuda .empty_cache ()
25
22
pl .seed_everything (42 )
@@ -37,7 +34,6 @@ def __init__(
37
34
):
38
35
"""
39
36
initiates a PyTorch Dataset Module for input data
40
-
41
37
Args:
42
38
data (pd.DataFrame): input pandas dataframe. Dataframe must have 2 column --> "source_text" and "target_text"
43
39
tokenizer (PreTrainedTokenizer): a PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, or ByT5Tokenizer)
@@ -85,8 +81,6 @@ def __getitem__(self, index: int):
85
81
] = - 100 # to make sure we have correct labels for T5 text generation
86
82
87
83
return dict (
88
- source_text = source_text ,
89
- target_text = data_row ["target_text" ],
90
84
source_text_input_ids = source_text_encoding ["input_ids" ].flatten (),
91
85
source_text_attention_mask = source_text_encoding ["attention_mask" ].flatten (),
92
86
labels = labels .flatten (),
@@ -105,10 +99,10 @@ def __init__(
105
99
batch_size : int = 4 ,
106
100
source_max_token_len : int = 512 ,
107
101
target_max_token_len : int = 512 ,
102
+ num_workers : int = 2 ,
108
103
):
109
104
"""
110
105
initiates a PyTorch Lightning Data Module
111
-
112
106
Args:
113
107
train_df (pd.DataFrame): training dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
114
108
test_df (pd.DataFrame): validation dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
@@ -125,6 +119,7 @@ def __init__(
125
119
self .tokenizer = tokenizer
126
120
self .source_max_token_len = source_max_token_len
127
121
self .target_max_token_len = target_max_token_len
122
+ self .num_workers = num_workers
128
123
129
124
def setup (self , stage = None ):
130
125
self .train_dataset = PyTorchDataModule (
@@ -143,38 +138,56 @@ def setup(self, stage=None):
143
138
def train_dataloader (self ):
144
139
""" training dataloader """
145
140
return DataLoader (
146
- self .train_dataset , batch_size = self .batch_size , shuffle = True , num_workers = 2
141
+ self .train_dataset ,
142
+ batch_size = self .batch_size ,
143
+ shuffle = True ,
144
+ num_workers = self .num_workers ,
147
145
)
148
146
149
147
def test_dataloader (self ):
150
148
""" test dataloader """
151
149
return DataLoader (
152
- self .test_dataset , batch_size = self .batch_size , shuffle = False , num_workers = 2
150
+ self .test_dataset ,
151
+ batch_size = self .batch_size ,
152
+ shuffle = False ,
153
+ num_workers = self .num_workers ,
153
154
)
154
155
155
156
def val_dataloader (self ):
156
157
""" validation dataloader """
157
158
return DataLoader (
158
- self .test_dataset , batch_size = self .batch_size , shuffle = False , num_workers = 2
159
+ self .test_dataset ,
160
+ batch_size = self .batch_size ,
161
+ shuffle = False ,
162
+ num_workers = self .num_workers ,
159
163
)
160
164
161
165
162
166
class LightningModel (pl .LightningModule ):
163
167
""" PyTorch Lightning Model class"""
164
168
165
- def __init__ (self , tokenizer , model , outputdir : str = "outputs" ):
169
+ def __init__ (
170
+ self ,
171
+ tokenizer ,
172
+ model ,
173
+ outputdir : str = "outputs" ,
174
+ save_only_last_epoch : bool = False ,
175
+ ):
166
176
"""
167
177
initiates a PyTorch Lightning Model
168
-
169
178
Args:
170
179
tokenizer : T5/MT5/ByT5 tokenizer
171
180
model : T5/MT5/ByT5 model
172
181
outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
182
+ save_only_last_epoch (bool, optional): If True, save just the last epoch else models are saved for every epoch
173
183
"""
174
184
super ().__init__ ()
175
185
self .model = model
176
186
self .tokenizer = tokenizer
177
187
self .outputdir = outputdir
188
+ self .average_training_loss = None
189
+ self .average_validation_loss = None
190
+ self .save_only_last_epoch = save_only_last_epoch
178
191
179
192
def forward (self , input_ids , attention_mask , decoder_attention_mask , labels = None ):
180
193
""" forward step """
@@ -201,7 +214,9 @@ def training_step(self, batch, batch_size):
201
214
labels = labels ,
202
215
)
203
216
204
- self .log ("train_loss" , loss , prog_bar = True , logger = True )
217
+ self .log (
218
+ "train_loss" , loss , prog_bar = True , logger = True , on_epoch = True , on_step = True
219
+ )
205
220
return loss
206
221
207
222
def validation_step (self , batch , batch_size ):
@@ -218,7 +233,9 @@ def validation_step(self, batch, batch_size):
218
233
labels = labels ,
219
234
)
220
235
221
- self .log ("val_loss" , loss , prog_bar = True , logger = True )
236
+ self .log (
237
+ "val_loss" , loss , prog_bar = True , logger = True , on_epoch = True , on_step = True
238
+ )
222
239
return loss
223
240
224
241
def test_step (self , batch , batch_size ):
@@ -244,19 +261,25 @@ def configure_optimizers(self):
244
261
245
262
def training_epoch_end (self , training_step_outputs ):
246
263
""" save tokenizer and model on epoch end """
247
- avg_traning_loss = np .round (
264
+ self . average_training_loss = np .round (
248
265
torch .mean (torch .stack ([x ["loss" ] for x in training_step_outputs ])).item (),
249
266
4 ,
250
267
)
251
- path = f"{ self .outputdir } /simplet5-epoch-{ self .current_epoch } -train-loss-{ str (avg_traning_loss )} "
252
- self .tokenizer .save_pretrained (path )
253
- self .model .save_pretrained (path )
268
+ path = f"{ self .outputdir } /simplet5-epoch-{ self .current_epoch } -train-loss-{ str (self .average_training_loss )} -val-loss-{ str (self .average_validation_loss )} "
269
+ if self .save_only_last_epoch :
270
+ if self .current_epoch == self .trainer .max_epochs - 1 :
271
+ self .tokenizer .save_pretrained (path )
272
+ self .model .save_pretrained (path )
273
+ else :
274
+ self .tokenizer .save_pretrained (path )
275
+ self .model .save_pretrained (path )
254
276
255
- # def validation_epoch_end(self, validation_step_outputs):
256
- # # val_loss = torch.stack([x['loss'] for x in validation_step_outputs]).mean()
257
- # path = f"{self.outputdir}/T5-epoch-{self.current_epoch}"
258
- # self.tokenizer.save_pretrained(path)
259
- # # self.model.save_pretrained(path)
277
+ def validation_epoch_end (self , validation_step_outputs ):
278
+ _loss = [x .cpu () for x in validation_step_outputs ]
279
+ self .average_validation_loss = np .round (
280
+ torch .mean (torch .stack (_loss )).item (),
281
+ 4 ,
282
+ )
260
283
261
284
262
285
class SimpleT5 :
@@ -269,7 +292,6 @@ def __init__(self) -> None:
269
292
def from_pretrained (self , model_type = "t5" , model_name = "t5-base" ) -> None :
270
293
"""
271
294
loads T5/MT5 Model model for training/finetuning
272
-
273
295
Args:
274
296
model_type (str, optional): "t5" or "mt5" . Defaults to "t5".
275
297
model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
@@ -302,10 +324,12 @@ def train(
302
324
outputdir : str = "outputs" ,
303
325
early_stopping_patience_epochs : int = 0 , # 0 to disable early stopping feature
304
326
precision = 32 ,
327
+ logger = "default" ,
328
+ dataloader_num_workers : int = 2 ,
329
+ save_only_last_epoch : bool = False ,
305
330
):
306
331
"""
307
332
trains T5/MT5 model on custom dataset
308
-
309
333
Args:
310
334
train_df (pd.DataFrame): training datarame. Dataframe must have 2 column --> "source_text" and "target_text"
311
335
eval_df ([type], optional): validation datarame. Dataframe must have 2 column --> "source_text" and "target_text"
@@ -317,65 +341,64 @@ def train(
317
341
outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
318
342
early_stopping_patience_epochs (int, optional): monitors val_loss on epoch end and stops training, if val_loss does not improve after the specied number of epochs. set 0 to disable early stopping. Defaults to 0 (disabled)
319
343
precision (int, optional): sets precision training - Double precision (64), full precision (32) or half precision (16). Defaults to 32.
344
+ logger (pytorch_lightning.loggers) : any logger supported by PyTorch Lightning. Defaults to "default". If "default", pytorch lightning default logger is used.
345
+ dataloader_num_workers (int, optional): number of workers in train/test/val dataloader
346
+ save_only_last_epoch (bool, optional): If True, saves only the last epoch else models are saved at every epoch
320
347
"""
321
- self .target_max_token_len = target_max_token_len
322
348
self .data_module = LightningDataModule (
323
349
train_df ,
324
350
eval_df ,
325
351
self .tokenizer ,
326
352
batch_size = batch_size ,
327
353
source_max_token_len = source_max_token_len ,
328
354
target_max_token_len = target_max_token_len ,
355
+ num_workers = dataloader_num_workers ,
329
356
)
330
357
331
358
self .T5Model = LightningModel (
332
- tokenizer = self .tokenizer , model = self .model , outputdir = outputdir
359
+ tokenizer = self .tokenizer ,
360
+ model = self .model ,
361
+ outputdir = outputdir ,
362
+ save_only_last_epoch = save_only_last_epoch ,
333
363
)
334
364
335
- # checkpoint_callback = ModelCheckpoint(
336
- # dirpath="checkpoints",
337
- # filename="best-checkpoint-{epoch}-{train_loss:.2f}",
338
- # save_top_k=-1,
339
- # verbose=True,
340
- # monitor="train_loss",
341
- # mode="min",
342
- # )
343
-
344
- # logger = TensorBoardLogger("SimpleT5", name="SimpleT5-Logger")
345
-
346
- early_stop_callback = (
347
- [
348
- EarlyStopping (
349
- monitor = "val_loss" ,
350
- min_delta = 0.00 ,
351
- patience = early_stopping_patience_epochs ,
352
- verbose = True ,
353
- mode = "min" ,
354
- )
355
- ]
356
- if early_stopping_patience_epochs > 0
357
- else None
358
- )
365
+ # add callbacks
366
+ callbacks = [TQDMProgressBar (refresh_rate = 5 )]
359
367
368
+ if early_stopping_patience_epochs > 0 :
369
+ early_stop_callback = EarlyStopping (
370
+ monitor = "val_loss" ,
371
+ min_delta = 0.00 ,
372
+ patience = early_stopping_patience_epochs ,
373
+ verbose = True ,
374
+ mode = "min" ,
375
+ )
376
+ callbacks .append (early_stop_callback )
377
+
378
+ # add gpu support
360
379
gpus = 1 if use_gpu else 0
361
380
381
+ # add logger
382
+ loggers = True if logger == "default" else logger
383
+
384
+ # prepare trainer
362
385
trainer = pl .Trainer (
363
- # logger=logger ,
364
- callbacks = early_stop_callback ,
386
+ logger = loggers ,
387
+ callbacks = callbacks ,
365
388
max_epochs = max_epochs ,
366
389
gpus = gpus ,
367
- progress_bar_refresh_rate = 5 ,
368
390
precision = precision ,
391
+ log_every_n_steps = 1 ,
369
392
)
370
393
394
+ # fit trainer
371
395
trainer .fit (self .T5Model , self .data_module )
372
396
373
397
def load_model (
374
398
self , model_type : str = "t5" , model_dir : str = "outputs" , use_gpu : bool = False
375
399
):
376
400
"""
377
401
loads a checkpoint for inferencing/prediction
378
-
379
402
Args:
380
403
model_type (str, optional): "t5" or "mt5". Defaults to "t5".
381
404
model_dir (str, optional): path to model directory. Defaults to "outputs".
@@ -418,7 +441,6 @@ def predict(
418
441
):
419
442
"""
420
443
generates prediction for T5/MT5 model
421
-
422
444
Args:
423
445
source_text (str): any text for generating predictions
424
446
max_length (int, optional): max token length of prediction. Defaults to 512.
@@ -432,7 +454,6 @@ def predict(
432
454
early_stopping (bool, optional): Defaults to True.
433
455
skip_special_tokens (bool, optional): Defaults to True.
434
456
clean_up_tokenization_spaces (bool, optional): Defaults to True.
435
-
436
457
Returns:
437
458
list[str]: returns predictions
438
459
"""
@@ -459,20 +480,4 @@ def predict(
459
480
)
460
481
for g in generated_ids
461
482
]
462
- return preds
463
-
464
- # def convert_and_load_onnx_model(self, model_dir: str):
465
- # """ returns ONNX model """
466
- # self.onnx_model = export_and_get_onnx_model(model_dir)
467
- # self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)
468
-
469
- # def onnx_predict(self, source_text: str):
470
- # """ generates prediction from ONNX model """
471
- # token = self.onnx_tokenizer(source_text, return_tensors="pt")
472
- # tokens = self.onnx_model.generate(
473
- # input_ids=token["input_ids"],
474
- # attention_mask=token["attention_mask"],
475
- # num_beams=2,
476
- # )
477
- # output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
478
- # return output
483
+ return preds
0 commit comments