34
34
35
35
parser = argparse .ArgumentParser ()
36
36
parser .add_argument (
37
- "--config-downstream" , default = "bert_hypertuning_config_classifier " ,
37
+ "--config-downstream" , default = "config_classifier " ,
38
38
help = "Configuration of the downstream part of the model" )
39
39
parser .add_argument (
40
40
'--pretrained-model-name' , type = str , default = 'bert-base-uncased' ,
72
72
class ModelWrapper (nn .Module ):
73
73
r"""This class wraps a model (in this case a BERT classifier) and implements
74
74
:meth:`forward` and :meth:`predict` to conform to the requirements of
75
- :class:`Executor` class. Particularly, :meth:`forward` returns a dict with
76
- keys "loss" and "preds" and :meth:`predict` returns a dict with key "preds".
75
+ :class:`texar.torch.run.Executor` class. Particularly, :meth:`forward`
76
+ returns a dict with keys "loss" and "preds" and :meth:`predict` returns a
77
+ dict with key "preds".
77
78
78
79
Args:
79
80
`model`: BERTClassifier
@@ -96,13 +97,15 @@ def _compute_loss(self, logits, labels):
96
97
97
98
def forward (self , # type: ignore
98
99
batch : tx .data .Batch ) -> Dict [str , torch .Tensor ]:
99
- r"""Run forward through the network and return a dict to be consumed
100
- by the :class:`Executor`. This method will be called by
101
- :class:``Executor` during training.
100
+ r"""Run forward through the model and return a dict to be consumed
101
+ by the :class:`texar.torch.run.Executor`. This method will be called by
102
+ :class:`texar.torch.run.Executor` during training. See Executor's
103
+ `General args <https://texar-pytorch.readthedocs.io/en/latest/code/run.html#executor-general-args>`
104
+ for more details.
102
105
103
106
Args:
104
- `batch`: tx .data.Batch
105
- A batch of inputs to be passed through the network
107
+ `batch`: :class:`texar .data.Batch`
108
+ A batch of inputs to be passed through the model
106
109
107
110
Returns:
108
111
A dict with keys "loss" and "preds" containing the loss and
@@ -123,8 +126,8 @@ def forward(self, # type: ignore
123
126
def predict (self , batch : tx .data .Batch ) -> Dict [str , torch .Tensor ]:
124
127
r"""Predict the labels for the :attr:`batch` of examples. This method
125
128
will be called instead of :meth:`forward` during validation or testing,
126
- if :class:`Executor`'s :attr:`validate_mode` or :attr:`test_mode` is set
127
- to ``"predict"`` instead of ``"eval"``.
129
+ if :class:`texar.torch.run. Executor`'s :attr:`validate_mode` or
130
+ :attr:`test_mode` is set to ``"predict"`` instead of ``"eval"``.
128
131
129
132
Args:
130
133
`batch`: tx.data.Batch
@@ -152,7 +155,7 @@ class TPE:
152
155
153
156
"""
154
157
def __init__ (self , model_config : Dict , output_dir : str = "output/" ):
155
- tx .utils .maybe_create_dir (args . output_dir )
158
+ tx .utils .maybe_create_dir (output_dir )
156
159
157
160
self .model_config = model_config
158
161
@@ -190,8 +193,8 @@ def __init__(self, model_config: Dict, output_dir: str = "output/"):
190
193
191
194
self .optim = tx .core .BertAdam
192
195
193
- def objective_func (self , params : Dict ):
194
- r"""Compute a "loss" for a given hyperparameter values. This function is
196
+ def objective_func (self , hyperparams : Dict ):
197
+ r"""Compute "loss" for a given hyperparameter values. This function is
195
198
passed to hyperopt's ``"fmin"`` (see the :meth:`run` method) function
196
199
and gets repeatedly called to find the best set of hyperparam values.
197
200
Below is an example of how to use this method
@@ -208,24 +211,24 @@ def objective_func(self, params: Dict):
208
211
trials=trials)
209
212
210
213
Args:
211
- params : Dict
214
+ hyperparams : Dict
212
215
A `(key, value)` dict representing the ``"value"`` to try for
213
216
the hyperparam ``"key"``
214
217
215
218
Returns:
216
219
A dict with keys "loss", "status" and "model" indicating the loss
217
220
for this trial, the status, and the path to the saved model.
218
221
"""
219
- print (f"Using { params } for trial { self .exp_number } " )
222
+ print (f"Using { hyperparams } for trial { self .exp_number } " )
220
223
221
224
# Loads data
222
225
num_train_data = config_data .num_train_data
223
226
num_train_steps = int (num_train_data / config_data .train_batch_size *
224
227
config_data .max_train_epoch )
225
228
226
229
# hyperparams
227
- num_warmup_steps = params ["optimizer.warmup_steps" ]
228
- static_lr = params ["optimizer.static_lr" ]
230
+ num_warmup_steps = hyperparams ["optimizer.warmup_steps" ]
231
+ static_lr = hyperparams ["optimizer.static_lr" ]
229
232
230
233
vars_with_decay = []
231
234
vars_without_decay = []
0 commit comments