Commit 1fc8aea

Add Hyperopt example for BERT classifier
Author: Avinash
1 parent aa31d85 commit 1fc8aea

File tree

6 files changed: +339 -25 lines changed


examples/bert/README.md (+22)
@@ -15,6 +15,7 @@ To summarize, this example showcases:
 * Building and fine-tuning on downstream tasks
 * Use of Texar `RecordData` module for data loading and processing
 * Use of Texar `Executor` module for simplified training loops and TensorBoard visualization
+* Use of the Hyperopt library to tune hyperparameters with the `Executor` module

 Future work:

@@ -178,3 +179,24 @@ tensorboard --logdir runs/
 ```

 ![Visualizing loss/accuarcy on Tensorboard](tbx.png)
+
+## Hyperparameter tuning with Executor
+
+To run this example, first install `hyperopt` with the following command:
+
+```commandline
+pip install hyperopt
+```
+
+`bert_with_tpe.py` shows an example of how to tune hyperparameters with `Executor` using `hyperopt`.
+Run it with the following command:
+
+```commandline
+python bert_with_tpe.py
+```
+
+In this simple example, the hyperparameters to be tuned are provided as a `dict` in
+`bert_tpe_config_classifier.py`, which is fed into `objective_func()`. We use the TPE algorithm
+(provided by the `hyperopt` library) to tune the hyperparameters. The example runs 3 trials to
+find the best hyperparameter settings. The final model is saved in the `models/exp{exp_number}`
+folder. More information about the library can be found at [Hyperopt](https://github.com/hyperopt/hyperopt).
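
For readers new to `hyperopt`, the tuning loop in `bert_with_tpe.py` boils down to the library's standard pattern: an objective function that runs one trial and returns a loss, a search space built from `hp` expressions, and `fmin` driven by the TPE suggestion algorithm. The sketch below illustrates that pattern in isolation; `train_and_eval` is a toy surrogate standing in for the `Executor` training run that the real script performs.

```python
import hyperopt as hpo


def train_and_eval(params):
    # Placeholder for one Executor training run; in bert_with_tpe.py this is
    # where the model is trained and the validation loss is computed. A toy
    # surrogate keeps this sketch runnable.
    return (params["optimizer.static_lr"] - 5e-3) ** 2


def objective_func(params):
    # hyperopt minimizes "loss"; STATUS_OK marks the trial as successful.
    return {"loss": train_and_eval(params), "status": hpo.STATUS_OK}


# Search space mirroring the ranges in bert_tpe_config_classifier.py.
space = {
    "optimizer.warmup_steps": hpo.hp.choice(
        "optimizer.warmup_steps", range(10000, 20000)),
    "optimizer.static_lr": hpo.hp.uniform("optimizer.static_lr", 1e-3, 1e-2),
}

trials = hpo.Trials()
best = hpo.fmin(fn=objective_func, space=space,
                algo=hpo.tpe.suggest, max_evals=3, trials=trials)
print(best)  # best values found (hp.choice entries are reported as indices)
```
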
examples/bert/bert_tpe_config_classifier.py (+11)
@@ -0,0 +1,11 @@
+name = "bert_classifier"
+hidden_size = 768
+clas_strategy = "cls_time"
+dropout = 0.1
+num_classes = 2
+
+# hyperparams
+hyperparams = {
+    "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int},
+    "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float}
+}
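
The `TPE.run()` method in `bert_with_tpe.py` (the next file) converts each entry of this `hyperparams` dict into a `hyperopt` search-space expression: integer ranges become `hp.choice` over the range of values, and float ranges become `hp.uniform` between `start` and `end`. Extracted as a standalone helper (the `build_space` name is only for illustration), the conversion looks roughly like this:

```python
import hyperopt as hpo


def build_space(hyperparams):
    # Mirror of the conversion loop in TPE.run(): one hp expression per entry.
    space = {}
    for name, spec in hyperparams.items():
        if spec["dtype"] == int:
            # Integers: pick one value out of range(start, end).
            space[name] = hpo.hp.choice(
                name, range(spec["start"], spec["end"]))
        else:
            # Floats: sample uniformly between start and end.
            space[name] = hpo.hp.uniform(name, spec["start"], spec["end"])
    return space
```

In the script, `fmin` then explores this space with `hpo.tpe.suggest` for `max_evals=3` trials.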

examples/bert/bert_with_tpe.py (+267)
@@ -0,0 +1,267 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import importlib
+import logging
+import shutil
+from typing import Dict
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import hyperopt as hpo
+
+import texar.torch as tx
+from texar.torch.run import *
+from texar.torch.modules import BERTClassifier
+
+from utils import model_utils
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--config-downstream", default="bert_tpe_config_classifier",
+    help="Configuration of the downstream part of the model")
+parser.add_argument(
+    '--pretrained-model-name', type=str, default='bert-base-uncased',
+    choices=tx.modules.BERTEncoder.available_checkpoints(),
+    help="Name of the pre-trained checkpoint to load.")
+parser.add_argument(
+    "--config-data", default="config_data", help="The dataset config.")
+parser.add_argument(
+    "--output-dir", default="output/",
+    help="The output directory where the model checkpoints will be written.")
+parser.add_argument(
+    "--checkpoint", type=str, default=None,
+    help="Path to a model checkpoint (including bert modules) to restore from.")
+parser.add_argument(
+    "--do-train", action="store_true", help="Whether to run training.")
+parser.add_argument(
+    "--do-eval", action="store_true",
+    help="Whether to run eval on the dev set.")
+parser.add_argument(
+    "--do-test", action="store_true",
+    help="Whether to run test on the test set.")
+args = parser.parse_args()
+
+config_data = importlib.import_module(args.config_data)
+config_downstream = importlib.import_module(args.config_downstream)
+config_downstream = {
+    k: v for k, v in config_downstream.__dict__.items()
+    if not k.startswith('__')}
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+logging.root.setLevel(logging.INFO)
+
+
+class ModelWrapper(nn.Module):
+    def __init__(self, model: BERTClassifier):
+        super().__init__()
+        self.model = model
+
+    def _compute_loss(self, logits, labels):
+        r"""Compute loss.
+        """
+        if self.model.is_binary:
+            loss = F.binary_cross_entropy(
+                logits.view(-1), labels.view(-1), reduction='mean')
+        else:
+            loss = F.cross_entropy(
+                logits.view(-1, self.model.num_classes),
+                labels.view(-1), reduction='mean')
+        return loss
+
+    def forward(self,  # type: ignore
+                batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
+        input_ids = batch["input_ids"]
+        segment_ids = batch["segment_ids"]
+        labels = batch["label_ids"]
+
+        input_length = (1 - (input_ids == 0).int()).sum(dim=1)
+
+        logits, preds = self.model(input_ids, input_length, segment_ids)
+
+        loss = self._compute_loss(logits, labels)
+
+        return {"loss": loss, "preds": preds}
+
+    def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
+        input_ids = batch["input_ids"]
+        segment_ids = batch["segment_ids"]
+
+        input_length = (1 - (input_ids == 0).int()).sum(dim=1)
+
+        _, preds = self.model(input_ids, input_length, segment_ids)
+
+        return {"preds": preds}
+
+
+class TPE:
+    def __init__(self, model_config=None):
+        tx.utils.maybe_create_dir(args.output_dir)
+
+        self.model_config = model_config
+
+        # create datasets
+        self.train_dataset = tx.data.RecordData(
+            hparams=config_data.train_hparam, device=device)
+        self.eval_dataset = tx.data.RecordData(
+            hparams=config_data.eval_hparam, device=device)
+
+        # Builds BERT
+        model = tx.modules.BERTClassifier(
+            pretrained_model_name=args.pretrained_model_name,
+            hparams=self.model_config)
+        self.model = ModelWrapper(model=model)
+        self.model.to(device)
+
+        # batching
+        self.batching_strategy = tx.data.TokenCountBatchingStrategy(
+            max_tokens=config_data.max_batch_tokens)
+
+        # logging formats
+        self.log_format = "{time} : Epoch {epoch:2d} @ {iteration:6d}it " \
+                          "({progress}%, {speed}), " \
+                          "lr = {lr:.9e}, loss = {loss:.3f}"
+        self.valid_log_format = "{time} : Epoch {epoch}, " \
+                                "{split} accuracy = {Accuracy:.3f}, " \
+                                "loss = {loss:.3f}"
+        self.valid_progress_log_format = "{time} : Evaluating on " \
+                                         "{split} ({progress}%, {speed})"
+
+        # exp number
+        self.exp_number = 1
+
+        self.optim = tx.core.BertAdam
+
+    def objective_func(self, params: Dict):
+
+        print(f"Using {params} for trial {self.exp_number}")
+
+        # Loads data
+        num_train_data = config_data.num_train_data
+        num_train_steps = int(num_train_data / config_data.train_batch_size *
+                              config_data.max_train_epoch)
+
+        # hyperparams
+        num_warmup_steps = params["optimizer.warmup_steps"]
+        static_lr = params["optimizer.static_lr"]
+
+        vars_with_decay = []
+        vars_without_decay = []
+        for name, param in self.model.named_parameters():
+            if 'layer_norm' in name or name.endswith('bias'):
+                vars_without_decay.append(param)
+            else:
+                vars_with_decay.append(param)
+
+        opt_params = [{
+            'params': vars_with_decay,
+            'weight_decay': 0.01,
+        }, {
+            'params': vars_without_decay,
+            'weight_decay': 0.0,
+        }]
+
+        optim = self.optim(opt_params, betas=(0.9, 0.999), eps=1e-6,
+                           lr=static_lr)
+
+        scheduler = torch.optim.lr_scheduler.LambdaLR(
+            optim, functools.partial(model_utils.get_lr_multiplier,
+                                     total_steps=num_train_steps,
+                                     warmup_steps=num_warmup_steps))
+
+        valid_metric = metric.Accuracy(pred_name="preds",
+                                       label_name="label_ids")
+        checkpoint_dir = f"./models/exp{self.exp_number}"
+
+        executor = Executor(
+            # supply executor with the model
+            model=self.model,
+            # define datasets
+            train_data=self.train_dataset,
+            valid_data=self.eval_dataset,
+            batching_strategy=self.batching_strategy,
+            device=device,
+            # training and stopping details
+            optimizer=optim,
+            lr_scheduler=scheduler,
+            stop_training_on=cond.epoch(config_data.max_train_epoch),
+            # logging details
+            log_every=[cond.epoch(1)],
+            # logging format
+            log_format=self.log_format,
+            # define metrics
+            train_metrics=[
+                ("loss", metric.RunningAverage(1)),
+                ("lr", metric.LR(optim))],
+            valid_metrics=[valid_metric, ("loss", metric.Average())],
+            validate_every=cond.epoch(1),
+            save_every=cond.epoch(config_data.max_train_epoch),
+            checkpoint_dir=checkpoint_dir,
+            max_to_keep=1,
+            show_live_progress=True,
+            print_model_arch=False
+        )
+
+        executor.train()
+
+        print(f"Loss on the valid dataset "
+              f"{executor.valid_metrics['loss'].value()}")
+        self.exp_number += 1
+
+        return {
+            "loss": executor.valid_metrics["loss"].value(),
+            "status": hpo.STATUS_OK,
+            "model": checkpoint_dir
+        }
+
+    def run(self, hyperparams: Dict):
+        space = {}
+        for k, v in hyperparams.items():
+            if isinstance(v, dict):
+                if v["dtype"] == int:
+                    space[k] = hpo.hp.choice(
+                        k, range(v["start"], v["end"]))
+                else:
+                    space[k] = hpo.hp.uniform(k, v["start"], v["end"])
+        trials = hpo.Trials()
+        hpo.fmin(fn=self.objective_func,
+                 space=space,
+                 algo=hpo.tpe.suggest,
+                 max_evals=3,
+                 trials=trials)
+        _, best_trial = min((trial["result"]["loss"], trial)
+                            for trial in trials.trials)
+
+        # delete all the other models
+        for trial in trials.trials:
+            if trial is not best_trial:
+                shutil.rmtree(trial["result"]["model"])
+
+
+def main():
+    model_config = {k: v for k, v in config_downstream.items() if
+                    k != "hyperparams"}
+    tpe = TPE(model_config=model_config)
+    hyperparams = config_downstream["hyperparams"]
+    tpe.run(hyperparams)
+
+
+if __name__ == '__main__':
+    main()

examples/bert/config_data.py (+1 -1)
@@ -7,7 +7,7 @@
 max_batch_tokens = 128

 train_batch_size = 32
-max_train_epoch = 5
+max_train_epoch = 3
 display_steps = 50 # Print training loss every display_steps; -1 to disable

 # tbx config

examples/bert/requirements.txt (+2 -1)
@@ -1,2 +1,3 @@
 tensorflow
-tensorboardX>=1.8
+tensorboardX>=1.8
+hyperopt
