
Commit e175fa6

Author: “Avinash”
Commit message: Address review comments
1 parent 1fc8aea commit e175fa6

File tree: 6 files changed (+100, -22 lines)


examples/bert/README.md (+9, -7)
@@ -15,7 +15,8 @@ To summarize, this example showcases:
 * Building and fine-tuning on downstream tasks
 * Use of Texar `RecordData` module for data loading and processing
 * Use of Texar `Executor` module for simplified training loops and TensorBoard visualization
-* Use of Hyperopt library to tune hyperparameters with `Executor` module
+* Use of [Hyperopt](https://github.com/hyperopt/hyperopt) library to tune hyperparameters with
+  `Executor` module
 
 Future work:
 
@@ -188,15 +189,16 @@ To run this example, please install `hyperopt` by issuing the following command
 pip install hyperopt
 ```
 
-`bert_with_tpe.py` shows an example of how to tune hyperparameters with Executor using `hyperopt`.
+`bert_with_hypertuning_main.py` shows an example of how to tune hyperparameters with Executor using `hyperopt`.
 To run this example, run the following command
 
 ```commandline
-python bert_with_tpe.py
+python bert_with_hypertuning_main.py
 ```
 
 In this simple example, the hyperparameters to be tuned are provided as a `dict` in
-`bert_tpe_config_classifier.py` which are fed into `objective_func()` . We use `TPE` algorithm for
-tuning the hyperparams (provided in `hyperopt` library). The example runs for 3 trials to find the
-best hyperparam settings. The final model is saved in `/model/{exp_number}` folder. More
-information about the libary can be found at [Hyperopt](https://github.com/hyperopt/hyperopt)
+`bert_hypertuning_config_classifier.py` which are fed into `objective_func()`. We use the `TPE`
+(Tree-structured Parzen Estimator) algorithm for tuning the hyperparams (provided in the `hyperopt`
+library). The example runs for 3 trials to find the best hyperparam settings. The final model is
+saved in the `output_dir` provided by the user. More information about the library can be
+found at [Hyperopt](https://github.com/hyperopt/hyperopt).
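
As an editorial aside, the sketch below shows one plausible shape for the tunable-hyperparameter section of `bert_hypertuning_config_classifier.py`. The specific keys and ranges are illustrative assumptions only; the diff establishes just that the config exposes a `hyperparams` dict (stripped out in `main()`, with the rest passed to `BERTClassifier`) whose values describe search ranges consumed by `TPE.run()`.

```python
# Hypothetical sketch of bert_hypertuning_config_classifier.py. The key names
# and ranges are assumptions; only the overall shape (a "hyperparams" dict of
# range descriptions, separate from the classifier config) is implied here.
name = "bert_classifier"
hidden_size = 768

hyperparams = {
    "optimizer.static_lr": {"start": 1e-5, "end": 1e-3, "dtype": float},
    "optimizer.warmup_steps": {"start": 5000, "end": 20000, "dtype": int},
}
```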

examples/bert/bert_with_tpe.py renamed to examples/bert/bert_with_hypertuning_main.py (+79, -7)
@@ -34,7 +34,7 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
-    "--config-downstream", default="bert_tpe_config_classifier",
+    "--config-downstream", default="bert_hypertuning_config_classifier",
     help="Configuration of the downstream part of the model")
 parser.add_argument(
     '--pretrained-model-name', type=str, default='bert-base-uncased',
@@ -70,13 +70,21 @@
 
 
 class ModelWrapper(nn.Module):
+    r"""This class wraps a model (in this case a BERT classifier) and implements
+    :meth:`forward` and :meth:`predict` to conform to the requirements of
+    :class:`Executor` class. Particularly, :meth:`forward` returns a dict with
+    keys "loss" and "preds" and :meth:`predict` returns a dict with key "preds".
+
+    Args:
+        `model`: BERTClassifier
+            A BERTClassifier model
+    """
+
     def __init__(self, model: BERTClassifier):
         super().__init__()
         self.model = model
 
     def _compute_loss(self, logits, labels):
-        r"""Compute loss.
-        """
         if self.model.is_binary:
             loss = F.binary_cross_entropy(
                 logits.view(-1), labels.view(-1), reduction='mean')
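
To make the `ModelWrapper` docstring's contract concrete, here is a self-contained toy sketch (not part of this commit): any `nn.Module` whose `forward()` returns a dict with keys `"loss"` and `"preds"`, and whose `predict()` returns a dict with key `"preds"`, satisfies what `Executor` consumes. The tiny linear model and the `"features"` key are purely illustrative.

```python
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyWrapper(nn.Module):
    """Illustrative stand-in for ModelWrapper: same return-dict contract."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.linear = nn.Linear(8, num_classes)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        logits = self.linear(batch["features"])
        loss = F.cross_entropy(logits, batch["label_ids"])
        preds = torch.argmax(logits, dim=-1)
        # Executor reads these two keys during training.
        return {"loss": loss, "preds": preds}

    def predict(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Called instead of forward() when the Executor mode is "predict".
        with torch.no_grad():
            preds = torch.argmax(self.linear(batch["features"]), dim=-1)
        return {"preds": preds}


batch = {"features": torch.randn(4, 8), "label_ids": torch.randint(0, 2, (4,))}
print(ToyWrapper()(batch)["loss"].item())
```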
@@ -88,6 +96,18 @@ def _compute_loss(self, logits, labels):
 
     def forward(self,  # type: ignore
                 batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
+        r"""Run forward through the network and return a dict to be consumed
+        by the :class:`Executor`. This method will be called by
+        :class:`Executor` during training.
+
+        Args:
+            `batch`: tx.data.Batch
+                A batch of inputs to be passed through the network
+
+        Returns:
+            A dict with keys "loss" and "preds" containing the loss and
+            predictions on :attr:`batch` respectively.
+        """
         input_ids = batch["input_ids"]
         segment_ids = batch["segment_ids"]
         labels = batch["label_ids"]
@@ -101,6 +121,15 @@ def forward(self, # type: ignore
         return {"loss": loss, "preds": preds}
 
     def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
+        r"""Predict the labels for the :attr:`batch` of examples. This method
+        will be called instead of :meth:`forward` during validation or testing,
+        if :class:`Executor`'s :attr:`validate_mode` or :attr:`test_mode` is set
+        to ``"predict"`` instead of ``"eval"``.
+
+        Args:
+            `batch`: tx.data.Batch
+                A batch of inputs to run prediction on
+        """
         input_ids = batch["input_ids"]
         segment_ids = batch["segment_ids"]
 
@@ -112,11 +141,23 @@ def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
 
 
 class TPE:
-    def __init__(self, model_config=None):
+    r""":class:`TPE` uses Tree-structured Parzen Estimator algorithm from
+    `hyperopt` for hyperparameter tuning.
+
+    Args:
+        model_config: Dict
+            A conf dict which is passed to BERT classifier
+        output_dir: str
+            A path to store the models
+
+    """
+    def __init__(self, model_config: Dict, output_dir: str = "output/"):
         tx.utils.maybe_create_dir(args.output_dir)
 
         self.model_config = model_config
 
+        self.output_dir = output_dir
+
         # create datasets
         self.train_dataset = tx.data.RecordData(
             hparams=config_data.train_hparam, device=device)
@@ -150,7 +191,31 @@ def __init__(self, model_config=None):
         self.optim = tx.core.BertAdam
 
     def objective_func(self, params: Dict):
-
+        r"""Compute a "loss" for the given hyperparameter values. This function
+        is passed to hyperopt's ``"fmin"`` (see the :meth:`run` method) function
+        and gets repeatedly called to find the best set of hyperparam values.
+        Below is an example of how to use this method
+
+        .. code-block:: python
+
+            import hyperopt as hpo
+
+            trials = hpo.Trials()
+            hpo.fmin(fn=self.objective_func,
+                     space=space,
+                     algo=hpo.tpe.suggest,
+                     max_evals=3,
+                     trials=trials)
+
+        Args:
+            params: Dict
+                A `(key, value)` dict representing the ``"value"`` to try for
+                the hyperparam ``"key"``
+
+        Returns:
+            A dict with keys "loss", "status" and "model" indicating the loss
+            for this trial, the status, and the path to the saved model.
+        """
         print(f"Using {params} for trial {self.exp_number}")
 
         # Loads data
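
Complementing the `fmin` snippet in the docstring above, the following runnable toy shows the same objective contract end to end: `hyperopt` minimizes the returned `"loss"`, while `"status"` and `"model"` are simply carried along in the trial results. The toy objective and the placeholder checkpoint path are assumptions, not the BERT training loop.

```python
import hyperopt as hpo
from hyperopt import STATUS_OK, hp


def toy_objective(params):
    # Pretend lr = 3e-4 is optimal; hyperopt minimizes the returned "loss".
    lr = params["lr"]
    fake_validation_accuracy = 1.0 - abs(lr - 3e-4)
    return {
        "loss": -fake_validation_accuracy,   # lower is better for fmin
        "status": STATUS_OK,
        "model": f"output/exp_lr_{lr:.0e}",  # placeholder checkpoint path
    }


trials = hpo.Trials()
best = hpo.fmin(fn=toy_objective,
                space={"lr": hp.uniform("lr", 1e-5, 1e-3)},
                algo=hpo.tpe.suggest,
                max_evals=3,                 # matches the 3 trials in the README
                trials=trials)
print(best)
```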
@@ -188,7 +253,7 @@ def objective_func(self, params: Dict):
 
         valid_metric = metric.Accuracy(pred_name="preds",
                                        label_name="label_ids")
-        checkpoint_dir = f"./models/exp{self.exp_number}"
+        checkpoint_dir = f"./{self.output_dir}/exp{self.exp_number}"
 
         executor = Executor(
             # supply executor with the model
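
As a plain-Python illustration (not `Executor` internals) of how these pieces connect: the validation metric built here reads the `"preds"` entry returned by `ModelWrapper.forward()`/`predict()` and compares it against the batch's `"label_ids"`, which is what `pred_name="preds"` and `label_name="label_ids"` refer to.

```python
import torch

# Toy illustration of the metric lookup: "preds" comes from the model's output
# dict, "label_ids" from the data batch. The values here are made up.
output = {"preds": torch.tensor([1, 0, 1, 1])}   # as returned by predict()
batch = {"label_ids": torch.tensor([1, 0, 0, 1])}
accuracy = (output["preds"] == batch["label_ids"]).float().mean().item()
print(accuracy)  # 0.75
```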
@@ -232,6 +297,13 @@ def objective_func(self, params: Dict):
         }
 
     def run(self, hyperparams: Dict):
+        r"""Run the TPE algorithm with hyperparameters :attr:`hyperparams`
+
+        Args:
+            hyperparams: Dict
+                The `(key, value)` pairs of hyperparameters along with their
+                range of values.
+        """
         space = {}
         for k, v in hyperparams.items():
             if isinstance(v, dict):
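
The visible lines cut `run()` off right after the `isinstance(v, dict)` check, so the following is only a hedged sketch of how such a `(key, value)` hyperparams dict could be mapped to a `hyperopt` search space; the `start`/`end`/`dtype` range format mirrors the config sketch given earlier and is an assumption, not something confirmed by this diff.

```python
from hyperopt import hp


def build_space(hyperparams):
    # Hypothetical dict -> hyperopt search-space conversion, mirroring what the
    # truncated run() loop above appears to do. The range format is assumed.
    space = {}
    for k, v in hyperparams.items():
        if isinstance(v, dict):
            if v.get("dtype") is int:
                # integer-valued range, quantized to steps of 1
                space[k] = hp.quniform(k, v["start"], v["end"], 1)
            else:
                space[k] = hp.uniform(k, v["start"], v["end"])
    return space


space = build_space({
    "optimizer.static_lr": {"start": 1e-5, "end": 1e-3, "dtype": float},
    "optimizer.warmup_steps": {"start": 5000, "end": 20000, "dtype": int},
})
print(space)
```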
@@ -258,7 +330,7 @@ def run(self, hyperparams: Dict):
 def main():
     model_config = {k: v for k, v in config_downstream.items() if
                     k != "hyperparams"}
-    tpe = TPE(model_config=model_config)
+    tpe = TPE(model_config=model_config, output_dir=args.output_dir)
     hyperparams = config_downstream["hyperparams"]
     tpe.run(hyperparams)
 
examples/bert/config_data.py (+1, -1)
@@ -7,7 +7,7 @@
 max_batch_tokens = 128
 
 train_batch_size = 32
-max_train_epoch = 3
+max_train_epoch = 5
 display_steps = 50  # Print training loss every display_steps; -1 to disable
 
 # tbx config

examples/bert/requirements.txt (+3, -3)
@@ -1,3 +1,3 @@
-tensorflow
-tensorboardX>=1.8
-hyperopt
+tensorflow  # used for loading BERT official model checkpoint
+tensorboardX>=1.8  # used only in bert_classifier_using_executor_main.py
+hyperopt  # used only in bert_with_hypertuning_main.py

texar/torch/run/executor.py (+8, -4)
@@ -1348,11 +1348,11 @@ def _try_get_data_size(executor: 'Executor'):
         finally:
             self._train_tracker.stop()
 
+        self._fire_event(Event.Training, True)
+
         # close the log files
         self._close_files()
 
-        self._fire_event(Event.Training, True)
-
     def test(self, dataset: OptionalDict[DataBase] = None):
         r"""Start the test loop.
 
@@ -1414,11 +1414,11 @@ def test(self, dataset: OptionalDict[DataBase] = None):
 
         self._fire_event(Event.Testing, True)
 
+        self.model.train(model_mode)
+
         # close the log files
         self._close_files()
 
-        self.model.train(model_mode)
-
     def _register_logging_actions(self, show_live_progress: List[str]):
         # Register logging actions.
         Points = Sequence[Union[Condition, Event]]
@@ -1701,6 +1701,10 @@ def _register_hook(self, event_point: EventPoint, action: ActionFn,
                 f"Specified hook point {event_point} is invalid") from None
 
     def _open_files(self):
+        self._opened_files = []
+        self._log_destination = []
+        self._log_destination_is_tty = []
+
         for dest in utils.to_list(self.log_destination):
             if isinstance(dest, (str, Path)):
                 # Append to the logs to prevent accidentally overwriting