Commit cb4d892

Author: Shivanandroy (committed)
added logger, save model only at last epoch, compatible with latest transformers + pytorch lightning
1 parent 94ad898 commit cb4d892

File tree

4 files changed: +90 -83 lines changed

README.md
requirements.txt
setup.py
simplet5/simplet5.py

README.md
Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ Here's a link to [Medium article](https://snrspeaks.medium.com/simplet5-train-t5

 ## Install
 ```python
+# It's advisable to create a new python environment and install simplet5
 pip install --upgrade simplet5
 ```
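The added README comment advises creating a fresh Python environment before installing. A minimal sketch of that setup (plain shell commands; the environment name is illustrative and not part of this commit):

```
python -m venv simplet5-env          # create an isolated environment
source simplet5-env/bin/activate     # activate it (Windows: simplet5-env\Scripts\activate)
pip install --upgrade simplet5       # install/upgrade simplet5 inside it
```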

requirements.txt
Lines changed: 3 additions & 2 deletions

@@ -1,5 +1,6 @@
+numpy
 pandas
 sentencepiece
 torch>=1.7.0,!=1.8.0
-transformers==4.10.0
-pytorch-lightning==1.4.5
+transformers==4.16.2
+pytorch-lightning==1.5.10

setup.py
Lines changed: 5 additions & 5 deletions

@@ -8,7 +8,7 @@

 setuptools.setup(
     name="simplet5",
-    version="0.1.3",
+    version="0.1.4",
     license="apache-2.0",
     author="Shivanand Roy",
     author_email="[email protected]",
@@ -39,12 +39,12 @@
     packages=setuptools.find_packages(),
     python_requires=">=3.5",
     install_requires=[
+        "numpy",
+        "pandas",
         "sentencepiece",
         "torch>=1.7.0,!=1.8.0", # excludes torch v1.8.0
-        "transformers==4.10.0",
-        "pytorch-lightning==1.4.5",
-        "tqdm"
-        # "fastt5==0.0.7",
+        "transformers==4.16.2",
+        "pytorch-lightning==1.5.10",
     ],
     classifiers=[
         "Intended Audience :: Developers",

simplet5/simplet5.py
Lines changed: 81 additions & 76 deletions

@@ -1,9 +1,7 @@
 import torch
 import numpy as np
 import pandas as pd
-from tqdm.auto import tqdm
 from transformers import (
-    AdamW,
     T5ForConditionalGeneration,
     MT5ForConditionalGeneration,
     ByT5Tokenizer,
@@ -12,14 +10,13 @@
     MT5TokenizerFast as MT5Tokenizer,
 )
 from transformers import AutoTokenizer
-
-# from fastT5 import export_and_get_onnx_model
+from torch.optim import AdamW
 from torch.utils.data import Dataset, DataLoader
 from transformers import AutoModelWithLMHead, AutoTokenizer
 import pytorch_lightning as pl
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks.progress import TQDMProgressBar

 torch.cuda.empty_cache()
 pl.seed_everything(42)
@@ -37,7 +34,6 @@ def __init__(
     ):
         """
         initiates a PyTorch Dataset Module for input data
-
         Args:
             data (pd.DataFrame): input pandas dataframe. Dataframe must have 2 column --> "source_text" and "target_text"
             tokenizer (PreTrainedTokenizer): a PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, or ByT5Tokenizer)
@@ -85,8 +81,6 @@ def __getitem__(self, index: int):
         ] = -100 # to make sure we have correct labels for T5 text generation

         return dict(
-            source_text=source_text,
-            target_text=data_row["target_text"],
             source_text_input_ids=source_text_encoding["input_ids"].flatten(),
             source_text_attention_mask=source_text_encoding["attention_mask"].flatten(),
             labels=labels.flatten(),
@@ -105,10 +99,10 @@ def __init__(
         batch_size: int = 4,
         source_max_token_len: int = 512,
         target_max_token_len: int = 512,
+        num_workers: int = 2,
     ):
         """
         initiates a PyTorch Lightning Data Module
-
         Args:
             train_df (pd.DataFrame): training dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
             test_df (pd.DataFrame): validation dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
@@ -125,6 +119,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.source_max_token_len = source_max_token_len
         self.target_max_token_len = target_max_token_len
+        self.num_workers = num_workers

     def setup(self, stage=None):
         self.train_dataset = PyTorchDataModule(
@@ -143,38 +138,56 @@ def setup(self, stage=None):
     def train_dataloader(self):
         """ training dataloader """
         return DataLoader(
-            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
         )

     def test_dataloader(self):
         """ test dataloader """
         return DataLoader(
-            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
         )

     def val_dataloader(self):
         """ validation dataloader """
         return DataLoader(
-            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
         )


 class LightningModel(pl.LightningModule):
     """ PyTorch Lightning Model class"""

-    def __init__(self, tokenizer, model, outputdir: str = "outputs"):
+    def __init__(
+        self,
+        tokenizer,
+        model,
+        outputdir: str = "outputs",
+        save_only_last_epoch: bool = False,
+    ):
         """
         initiates a PyTorch Lightning Model
-
         Args:
             tokenizer : T5/MT5/ByT5 tokenizer
             model : T5/MT5/ByT5 model
             outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
+            save_only_last_epoch (bool, optional): If True, save just the last epoch else models are saved for every epoch
         """
         super().__init__()
         self.model = model
         self.tokenizer = tokenizer
         self.outputdir = outputdir
+        self.average_training_loss = None
+        self.average_validation_loss = None
+        self.save_only_last_epoch = save_only_last_epoch

     def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
         """ forward step """
@@ -201,7 +214,9 @@ def training_step(self, batch, batch_size):
             labels=labels,
         )

-        self.log("train_loss", loss, prog_bar=True, logger=True)
+        self.log(
+            "train_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
+        )
         return loss

     def validation_step(self, batch, batch_size):
@@ -218,7 +233,9 @@ def validation_step(self, batch, batch_size):
             labels=labels,
         )

-        self.log("val_loss", loss, prog_bar=True, logger=True)
+        self.log(
+            "val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
+        )
         return loss

     def test_step(self, batch, batch_size):
@@ -244,19 +261,25 @@ def configure_optimizers(self):

     def training_epoch_end(self, training_step_outputs):
         """ save tokenizer and model on epoch end """
-        avg_traning_loss = np.round(
+        self.average_training_loss = np.round(
             torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
             4,
         )
-        path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(avg_traning_loss)}"
-        self.tokenizer.save_pretrained(path)
-        self.model.save_pretrained(path)
+        path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(self.average_training_loss)}-val-loss-{str(self.average_validation_loss)}"
+        if self.save_only_last_epoch:
+            if self.current_epoch == self.trainer.max_epochs - 1:
+                self.tokenizer.save_pretrained(path)
+                self.model.save_pretrained(path)
+        else:
+            self.tokenizer.save_pretrained(path)
+            self.model.save_pretrained(path)

-    # def validation_epoch_end(self, validation_step_outputs):
-    #     # val_loss = torch.stack([x['loss'] for x in validation_step_outputs]).mean()
-    #     path = f"{self.outputdir}/T5-epoch-{self.current_epoch}"
-    #     self.tokenizer.save_pretrained(path)
-    #     # self.model.save_pretrained(path)
+    def validation_epoch_end(self, validation_step_outputs):
+        _loss = [x.cpu() for x in validation_step_outputs]
+        self.average_validation_loss = np.round(
+            torch.mean(torch.stack(_loss)).item(),
+            4,
+        )


 class SimpleT5:
@@ -269,7 +292,6 @@ def __init__(self) -> None:
     def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
         """
         loads T5/MT5 Model model for training/finetuning
-
         Args:
             model_type (str, optional): "t5" or "mt5" . Defaults to "t5".
             model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
@@ -302,10 +324,12 @@ def train(
         outputdir: str = "outputs",
         early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
         precision=32,
+        logger="default",
+        dataloader_num_workers: int = 2,
+        save_only_last_epoch: bool = False,
     ):
         """
         trains T5/MT5 model on custom dataset
-
         Args:
             train_df (pd.DataFrame): training datarame. Dataframe must have 2 column --> "source_text" and "target_text"
             eval_df ([type], optional): validation datarame. Dataframe must have 2 column --> "source_text" and "target_text"
@@ -317,65 +341,64 @@ def train(
             outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
             early_stopping_patience_epochs (int, optional): monitors val_loss on epoch end and stops training, if val_loss does not improve after the specied number of epochs. set 0 to disable early stopping. Defaults to 0 (disabled)
             precision (int, optional): sets precision training - Double precision (64), full precision (32) or half precision (16). Defaults to 32.
+            logger (pytorch_lightning.loggers) : any logger supported by PyTorch Lightning. Defaults to "default". If "default", pytorch lightning default logger is used.
+            dataloader_num_workers (int, optional): number of workers in train/test/val dataloader
+            save_only_last_epoch (bool, optional): If True, saves only the last epoch else models are saved at every epoch
         """
-        self.target_max_token_len = target_max_token_len
         self.data_module = LightningDataModule(
             train_df,
             eval_df,
             self.tokenizer,
             batch_size=batch_size,
             source_max_token_len=source_max_token_len,
             target_max_token_len=target_max_token_len,
+            num_workers=dataloader_num_workers,
         )

         self.T5Model = LightningModel(
-            tokenizer=self.tokenizer, model=self.model, outputdir=outputdir
+            tokenizer=self.tokenizer,
+            model=self.model,
+            outputdir=outputdir,
+            save_only_last_epoch=save_only_last_epoch,
         )

-        # checkpoint_callback = ModelCheckpoint(
-        #     dirpath="checkpoints",
-        #     filename="best-checkpoint-{epoch}-{train_loss:.2f}",
-        #     save_top_k=-1,
-        #     verbose=True,
-        #     monitor="train_loss",
-        #     mode="min",
-        # )
-
-        # logger = TensorBoardLogger("SimpleT5", name="SimpleT5-Logger")
-
-        early_stop_callback = (
-            [
-                EarlyStopping(
-                    monitor="val_loss",
-                    min_delta=0.00,
-                    patience=early_stopping_patience_epochs,
-                    verbose=True,
-                    mode="min",
-                )
-            ]
-            if early_stopping_patience_epochs > 0
-            else None
-        )
+        # add callbacks
+        callbacks = [TQDMProgressBar(refresh_rate=5)]

+        if early_stopping_patience_epochs > 0:
+            early_stop_callback = EarlyStopping(
+                monitor="val_loss",
+                min_delta=0.00,
+                patience=early_stopping_patience_epochs,
+                verbose=True,
+                mode="min",
+            )
+            callbacks.append(early_stop_callback)
+
+        # add gpu support
         gpus = 1 if use_gpu else 0

+        # add logger
+        loggers = True if logger == "default" else logger
+
+        # prepare trainer
         trainer = pl.Trainer(
-            # logger=logger,
-            callbacks=early_stop_callback,
+            logger=loggers,
+            callbacks=callbacks,
             max_epochs=max_epochs,
             gpus=gpus,
-            progress_bar_refresh_rate=5,
             precision=precision,
+            log_every_n_steps=1,
         )

+        # fit trainer
         trainer.fit(self.T5Model, self.data_module)

     def load_model(
         self, model_type: str = "t5", model_dir: str = "outputs", use_gpu: bool = False
     ):
         """
         loads a checkpoint for inferencing/prediction
-
         Args:
             model_type (str, optional): "t5" or "mt5". Defaults to "t5".
             model_dir (str, optional): path to model directory. Defaults to "outputs".
@@ -418,7 +441,6 @@ def predict(
     ):
         """
         generates prediction for T5/MT5 model
-
         Args:
             source_text (str): any text for generating predictions
             max_length (int, optional): max token length of prediction. Defaults to 512.
@@ -432,7 +454,6 @@
             early_stopping (bool, optional): Defaults to True.
             skip_special_tokens (bool, optional): Defaults to True.
             clean_up_tokenization_spaces (bool, optional): Defaults to True.
-
         Returns:
             list[str]: returns predictions
         """
@@ -459,20 +480,4 @@ def predict(
             )
             for g in generated_ids
         ]
-        return preds
-
-    # def convert_and_load_onnx_model(self, model_dir: str):
-    #     """ returns ONNX model """
-    #     self.onnx_model = export_and_get_onnx_model(model_dir)
-    #     self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-    # def onnx_predict(self, source_text: str):
-    #     """ generates prediction from ONNX model """
-    #     token = self.onnx_tokenizer(source_text, return_tensors="pt")
-    #     tokens = self.onnx_model.generate(
-    #         input_ids=token["input_ids"],
-    #         attention_mask=token["attention_mask"],
-    #         num_beams=2,
-    #     )
-    #     output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
-    #     return output
+        return preds
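For context on the new `train()` options this commit introduces (`logger`, `dataloader_num_workers`, `save_only_last_epoch`), here is a hedged usage sketch; the dataframe contents and model name are placeholders, not part of the diff:

```python
# minimal sketch against simplet5 0.1.4; data and model name are illustrative
import pandas as pd
from simplet5 import SimpleT5

# train/eval dataframes must have "source_text" and "target_text" columns
train_df = pd.DataFrame({"source_text": ["summarize: a long article ..."], "target_text": ["a short summary"]})
eval_df = pd.DataFrame({"source_text": ["summarize: another article ..."], "target_text": ["another summary"]})

model = SimpleT5()
model.from_pretrained(model_type="t5", model_name="t5-base")
model.train(
    train_df=train_df,
    eval_df=eval_df,
    max_epochs=3,
    use_gpu=False,
    logger="default",            # or any PyTorch Lightning logger instance
    dataloader_num_workers=2,    # forwarded to the train/val/test DataLoaders
    save_only_last_epoch=True,   # write a checkpoint only after the final epoch
)
```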
