
Commit fa829b6

Author: Avinash
Commit message: Address review comments
1 parent e175fa6 commit fa829b6

File tree

3 files changed (+26 -28 lines)

examples/bert/bert_hypertuning_config_classifier.py (-11)

This file was deleted.

examples/bert/bert_with_hypertuning_main.py (+20 -17)
@@ -34,7 +34,7 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
-    "--config-downstream", default="bert_hypertuning_config_classifier",
+    "--config-downstream", default="config_classifier",
     help="Configuration of the downstream part of the model")
 parser.add_argument(
     '--pretrained-model-name', type=str, default='bert-base-uncased',
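
For reference, the new default points at config_classifier.py. This diff does not show how the flag is consumed; below is a minimal sketch, assuming the importlib-based module loading used by other texar-pytorch examples:

import argparse
import importlib

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config-downstream", default="config_classifier",
    help="Configuration of the downstream part of the model")
args = parser.parse_args()

# Assumed loading pattern (not shown in this diff): the config module name
# given on the command line is imported as a plain Python module.
config_downstream = importlib.import_module(args.config_downstream)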
@@ -72,8 +72,9 @@
 class ModelWrapper(nn.Module):
     r"""This class wraps a model (in this case a BERT classifier) and implements
     :meth:`forward` and :meth:`predict` to conform to the requirements of
-    :class:`Executor` class. Particularly, :meth:`forward` returns a dict with
-    keys "loss" and "preds" and :meth:`predict` returns a dict with key "preds".
+    the :class:`texar.torch.run.Executor` class. In particular, :meth:`forward`
+    returns a dict with keys "loss" and "preds" and :meth:`predict` returns a
+    dict with key "preds".
 
     Args:
         `model`: BERTClassifier
@@ -96,13 +97,15 @@ def _compute_loss(self, logits, labels):
 
     def forward(self,  # type: ignore
                 batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
-        r"""Run forward through the network and return a dict to be consumed
-        by the :class:`Executor`. This method will be called by
-        :class:``Executor` during training.
+        r"""Run forward through the model and return a dict to be consumed
+        by the :class:`texar.torch.run.Executor`. This method will be called by
+        :class:`texar.torch.run.Executor` during training. See Executor's
+        `General args <https://texar-pytorch.readthedocs.io/en/latest/code/run.html#executor-general-args>`_
+        for more details.
 
         Args:
-            `batch`: tx.data.Batch
-                A batch of inputs to be passed through the network
+            `batch`: :class:`texar.data.Batch`
+                A batch of inputs to be passed through the model
 
         Returns:
             A dict with keys "loss" and "preds" containing the loss and
@@ -123,8 +126,8 @@ def forward(self,  # type: ignore
     def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
         r"""Predict the labels for the :attr:`batch` of examples. This method
         will be called instead of :meth:`forward` during validation or testing,
-        if :class:`Executor`'s :attr:`validate_mode` or :attr:`test_mode` is set
-        to ``"predict"`` instead of ``"eval"``.
+        if :class:`texar.torch.run.Executor`'s :attr:`validate_mode` or
+        :attr:`test_mode` is set to ``"predict"`` instead of ``"eval"``.
 
         Args:
             `batch`: tx.data.Batch
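
Taken together, the two docstrings above describe the full interface Executor expects from a wrapped model. A minimal self-contained sketch of that contract follows; the wrapped model's call signature and the batch field names are assumptions modeled on texar's BERTClassifier and the BERT example data:

from typing import Dict

import torch
from torch import nn
import texar.torch as tx


class ModelWrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def _run_model(self, batch: tx.data.Batch):
        input_ids = batch["input_ids"]
        # Assumed: padding id 0, so nonzero positions give sequence lengths.
        input_length = (1 - (input_ids == 0).int()).sum(dim=1)
        return self.model(inputs=input_ids,
                          sequence_length=input_length,
                          segment_ids=batch["segment_ids"])

    def forward(self,  # type: ignore
                batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
        # Training: Executor consumes a dict with keys "loss" and "preds".
        logits, preds = self._run_model(batch)
        loss = torch.nn.functional.cross_entropy(logits, batch["label_ids"])
        return {"loss": loss, "preds": preds}

    def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
        # Validation/testing with mode "predict": only "preds" is required.
        _, preds = self._run_model(batch)
        return {"preds": preds}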
@@ -152,7 +155,7 @@ class TPE:
 
     """
     def __init__(self, model_config: Dict, output_dir: str = "output/"):
-        tx.utils.maybe_create_dir(args.output_dir)
+        tx.utils.maybe_create_dir(output_dir)
 
         self.model_config = model_config
 
@@ -190,8 +193,8 @@ def __init__(self, model_config: Dict, output_dir: str = "output/"):
 
         self.optim = tx.core.BertAdam
 
-    def objective_func(self, params: Dict):
-        r"""Compute a "loss" for a given hyperparameter values. This function is
+    def objective_func(self, hyperparams: Dict):
+        r"""Compute "loss" for given hyperparameter values. This function is
         passed to hyperopt's ``"fmin"`` (see the :meth:`run` method) function
         and gets repeatedly called to find the best set of hyperparam values.
         Below is an example of how to use this method
@@ -208,24 +211,24 @@ def objective_func(self, params: Dict):
                         trials=trials)
 
         Args:
-            params: Dict
+            hyperparams: Dict
                 A `(key, value)` dict representing the ``"value"`` to try for
                 the hyperparam ``"key"``
 
         Returns:
             A dict with keys "loss", "status" and "model" indicating the loss
             for this trial, the status, and the path to the saved model.
         """
-        print(f"Using {params} for trial {self.exp_number}")
+        print(f"Using {hyperparams} for trial {self.exp_number}")
 
         # Loads data
         num_train_data = config_data.num_train_data
         num_train_steps = int(num_train_data / config_data.train_batch_size *
                               config_data.max_train_epoch)
 
         # hyperparams
-        num_warmup_steps = params["optimizer.warmup_steps"]
-        static_lr = params["optimizer.static_lr"]
+        num_warmup_steps = hyperparams["optimizer.warmup_steps"]
+        static_lr = hyperparams["optimizer.static_lr"]
 
         vars_with_decay = []
         vars_without_decay = []
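
The elided usage snippet in the docstring refers to hyperopt's fmin. A hedged sketch of how objective_func is typically driven; the search-space bounds mirror the hyperparams dict added to config_classifier.py below, and the TPE constructor arguments here are placeholders:

from hyperopt import Trials, fmin, hp, tpe

# Illustrative search space; presumably built from the config's hyperparams
# dict in the real example.
space = {
    "optimizer.warmup_steps": hp.quniform(
        "optimizer.warmup_steps", 10000, 20000, 1),
    "optimizer.static_lr": hp.uniform("optimizer.static_lr", 1e-3, 1e-2),
}

trials = Trials()
# model_config={} is a placeholder; pass the real downstream model config.
tpe_runner = TPE(model_config={}, output_dir="output/")
best = fmin(fn=tpe_runner.objective_func,
            space=space,
            algo=tpe.suggest,
            max_evals=3,
            trials=trials)
print("Best hyperparams:", best)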

examples/bert/config_classifier.py (+6)
@@ -3,3 +3,9 @@
 clas_strategy = "cls_time"
 dropout = 0.1
 num_classes = 2
+
+# These hyperparams are used in the bert_with_hypertuning_main.py example
+hyperparams = {
+    "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int},
+    "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float}
+}
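
How the start/end/dtype entries map onto a hyperopt search space is not shown in this diff. One plausible translation, as a hypothetical helper (hp.quniform for integer ranges, hp.uniform for floats):

from hyperopt import hp


def build_space(hyperparams: dict) -> dict:
    # Hypothetical helper, not part of this commit: convert the
    # `hyperparams` spec above into a hyperopt search space.
    space = {}
    for name, spec in hyperparams.items():
        if spec["dtype"] is int:
            # quniform samples on a grid with step q=1; hyperopt still
            # returns floats, so callers should cast to int before use.
            space[name] = hp.quniform(name, spec["start"], spec["end"], 1)
        else:
            space[name] = hp.uniform(name, spec["start"], spec["end"])
    return space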
