parser = argparse.ArgumentParser()
parser.add_argument(
-    "--config-downstream", default="bert_tpe_config_classifier",
+    "--config-downstream", default="bert_hypertuning_config_classifier",
    help="Configuration of the downstream part of the model")
parser.add_argument(
    '--pretrained-model-name', type=str, default='bert-base-uncased',
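The value of this flag is a Python module name; presumably, as in the repository's other BERT examples, it is imported dynamically and flattened into a plain dict (the loading code itself is not part of this diff). A hedged sketch of that pattern:

```python
# Hedged sketch: loading the module named by --config-downstream. The module
# must be importable (e.g. a bert_hypertuning_config_classifier.py file next
# to the script); this loading code is assumed, not shown in the diff.
import argparse
import importlib

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config-downstream", default="bert_hypertuning_config_classifier",
    help="Configuration of the downstream part of the model")
args = parser.parse_args()

config_module = importlib.import_module(args.config_downstream)
# flatten module attributes into the dict that main() iterates over
config_downstream = {
    k: v for k, v in config_module.__dict__.items() if not k.startswith("__")
}
```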
class ModelWrapper(nn.Module):
+    r"""This class wraps a model (in this case a BERT classifier) and implements
+    :meth:`forward` and :meth:`predict` to conform to the requirements of the
+    :class:`Executor` class. In particular, :meth:`forward` returns a dict with
+    keys "loss" and "preds", and :meth:`predict` returns a dict with key "preds".
+
+    Args:
+        model: BERTClassifier
+            A BERTClassifier model
+    """
+
    def __init__(self, model: BERTClassifier):
        super().__init__()
        self.model = model

    def _compute_loss(self, logits, labels):
-        r"""Compute loss.
-        """
        if self.model.is_binary:
            loss = F.binary_cross_entropy(
                logits.view(-1), labels.view(-1), reduction='mean')
@@ -88,6 +96,18 @@ def _compute_loss(self, logits, labels):

    def forward(self,  # type: ignore
                batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
+        r"""Run a forward pass through the network and return a dict to be
+        consumed by the :class:`Executor`. This method will be called by the
+        :class:`Executor` during training.
+
+        Args:
+            batch: tx.data.Batch
+                A batch of inputs to be passed through the network
+
+        Returns:
+            A dict with keys "loss" and "preds" containing the loss and
+            predictions on :attr:`batch` respectively.
+        """
        input_ids = batch["input_ids"]
        segment_ids = batch["segment_ids"]
        labels = batch["label_ids"]
@@ -101,6 +121,15 @@ def forward(self, # type: ignore
        return {"loss": loss, "preds": preds}

    def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
+        r"""Predict the labels for the :attr:`batch` of examples. This method
+        will be called instead of :meth:`forward` during validation or testing,
+        if the :class:`Executor`'s :attr:`validate_mode` or :attr:`test_mode`
+        is set to ``"predict"`` instead of ``"eval"``.
+
+        Args:
+            batch: tx.data.Batch
+                A batch of inputs to run prediction on
+        """
        input_ids = batch["input_ids"]
        segment_ids = batch["segment_ids"]
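For reference, here is a minimal, self-contained sketch of the forward/predict contract the docstrings above describe, using a hypothetical toy model in place of the BERT classifier:

```python
# Minimal sketch of the Executor dict contract: forward() returns
# {"loss", "preds"} for training, predict() returns {"preds"} for
# "predict"-mode validation/testing. ToyWrapper is hypothetical.
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyWrapper(nn.Module):
    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.linear = nn.Linear(8, num_classes)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        logits = self.linear(batch["inputs"])
        loss = F.cross_entropy(logits, batch["label_ids"])
        preds = torch.argmax(logits, dim=-1)
        # "loss" drives training; "preds" feeds metrics such as
        # metric.Accuracy(pred_name="preds", label_name="label_ids")
        return {"loss": loss, "preds": preds}

    def predict(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        logits = self.linear(batch["inputs"])
        return {"preds": torch.argmax(logits, dim=-1)}


# Toy usage with random data
wrapper = ToyWrapper()
batch = {"inputs": torch.randn(4, 8), "label_ids": torch.randint(0, 2, (4,))}
out = wrapper(batch)
print(out["loss"].item(), out["preds"])
```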
@@ -112,11 +141,23 @@ def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]:

class TPE:
-    def __init__(self, model_config=None):
+    r""":class:`TPE` uses the Tree-structured Parzen Estimator algorithm from
+    `hyperopt` for hyperparameter tuning.
+
+    Args:
+        model_config: Dict
+            A config dict which is passed to the BERT classifier
+        output_dir: str
+            A path under which to store the models
+    """
+
+    def __init__(self, model_config: Dict, output_dir: str = "output/"):
        tx.utils.maybe_create_dir(args.output_dir)

        self.model_config = model_config
+        self.output_dir = output_dir
+
        # create datasets
        self.train_dataset = tx.data.RecordData(
            hparams=config_data.train_hparam, device=device)
@@ -150,7 +191,31 @@ def __init__(self, model_config=None):

        self.optim = tx.core.BertAdam

    def objective_func(self, params: Dict):
-
+        r"""Compute a "loss" for a given set of hyperparameter values. This
+        function is passed to hyperopt's ``fmin`` (see the :meth:`run` method)
+        and gets called repeatedly to find the best set of hyperparam values.
+        Below is an example of how to use this method:
+
+        .. code-block:: python
+
+            import hyperopt as hpo
+
+            trials = hpo.Trials()
+            hpo.fmin(fn=self.objective_func,
+                     space=space,
+                     algo=hpo.tpe.suggest,
+                     max_evals=3,
+                     trials=trials)
+
+        Args:
+            params: Dict
+                A `(key, value)` dict representing the ``value`` to try for
+                the hyperparam ``key``
+
+        Returns:
+            A dict with keys "loss", "status" and "model" indicating the loss
+            for this trial, the status, and the path to the saved model.
+        """
        print(f"Using {params} for trial {self.exp_number}")

        # Loads data
@@ -188,7 +253,7 @@ def objective_func(self, params: Dict):

        valid_metric = metric.Accuracy(pred_name="preds",
                                       label_name="label_ids")
-        checkpoint_dir = f"./models/exp{self.exp_number}"
+        checkpoint_dir = f"./{self.output_dir}/exp{self.exp_number}"

        executor = Executor(
            # supply executor with the model
@@ -232,6 +297,13 @@ def objective_func(self, params: Dict):

        }
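The closing brace above ends the trial-result dict that objective_func hands back to hyperopt. The diff truncates its contents, but the docstring names the keys; a hedged sketch of its shape, with placeholder values:

```python
# Hedged sketch of the dict objective_func returns for one trial. The metric
# value and path below are placeholders, not values taken from the diff.
from hyperopt import STATUS_OK

best_valid_accuracy = 0.87          # hypothetical validation accuracy
checkpoint_dir = "./output/exp0"    # mirrors the checkpoint_dir format above

trial_result = {
    "loss": 1.0 - best_valid_accuracy,  # hyperopt minimizes "loss"
    "status": STATUS_OK,
    "model": checkpoint_dir,
}
```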

    def run(self, hyperparams: Dict):
+        r"""Run the TPE algorithm with hyperparameters :attr:`hyperparams`.
+
+        Args:
+            hyperparams: Dict
+                The `(key, value)` pairs of hyperparameters along with their
+                range of values.
+        """
        space = {}
        for k, v in hyperparams.items():
            if isinstance(v, dict):
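The loop above (cut off by the hunk boundary) is where the range dicts from the config become a hyperopt search space. A hedged sketch of that translation, with an assumed config entry:

```python
# Hedged sketch: turning config-style range dicts into a hyperopt space.
# The "optimizer.lr" key and its bounds are assumptions for illustration.
import hyperopt as hpo

hyperparams = {"optimizer.lr": {"start": 1e-5, "end": 1e-4}}

space = {}
for k, v in hyperparams.items():
    if isinstance(v, dict):
        # sample uniformly between the configured bounds
        space[k] = hpo.hp.uniform(k, v["start"], v["end"])

print(space)
```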
@@ -258,7 +330,7 @@ def run(self, hyperparams: Dict):

def main():
    model_config = {k: v for k, v in config_downstream.items() if
                    k != "hyperparams"}
-    tpe = TPE(model_config=model_config)
+    tpe = TPE(model_config=model_config, output_dir=args.output_dir)
    hyperparams = config_downstream["hyperparams"]
    tpe.run(hyperparams)
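To see the whole loop end to end without the BERT model, here is a self-contained toy run of the same fmin/Trials pattern that TPE uses, with a quadratic objective standing in for the training trial:

```python
# Toy version of the TPE loop: toy_objective stands in for objective_func,
# and the space mirrors the kind that run() would build from the config.
import hyperopt as hpo
from hyperopt import STATUS_OK


def toy_objective(params):
    lr = params["lr"]
    # pretend the best validation loss occurs near lr = 3e-5
    return {"loss": (lr - 3e-5) ** 2, "status": STATUS_OK}


space = {"lr": hpo.hp.uniform("lr", 1e-5, 1e-4)}
trials = hpo.Trials()
best = hpo.fmin(fn=toy_objective, space=space,
                algo=hpo.tpe.suggest, max_evals=10, trials=trials)
print(best)  # e.g. {'lr': 3.02e-05}
```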