Skip to content

Commit 49377e8

Browse files
committed
Fix passing of keyword args to Dense layers in create_tower
Current behavior: kwargs are passed to tf.keras.Sequential.add, so they are not passed on to tf.keras.layers.Dense as intended. For example, passing `use_bias=False` (or any other Dense keyword argument, such as `kernel_regularizer`) to create_tower throws an exception: Traceback (most recent call last): File "/Users/brussell/development/ranking/tensorflow_ranking/python/keras/layers_test.py", line 33, in test_create_tower_with_kwargs tower = layers.create_tower([3, 2, 1], 1, activation='relu', use_bias=False) File "/Users/brussell/development/ranking/tensorflow_ranking/python/keras/layers.py", line 70, in create_tower model.add(tf.keras.layers.Dense(units=layer_width), **kwargs) File "/usr/local/anaconda3/lib/python3.9/site-packages/tensorflow/python/trackable/base.py", line 205, in _method_wrapper result = method(self, *args, **kwargs) File "/usr/local/anaconda3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 61, in error_handler return fn(*args, **kwargs) TypeError: add() got an unexpected keyword argument 'use_bias' (reproduced by the new test, test_create_tower_with_kwargs). Fix: This PR fixes the behavior by moving the closing paren of the tf.keras.layers.Dense call to the correct location, so **kwargs are passed to Dense rather than to Sequential.add.
1 parent dfae631 commit 49377e8

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

tensorflow_ranking/python/keras/layers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def create_tower(hidden_layer_dims: List[int],
5757
dropout: When not `None`, the probability we will drop out a given
5858
coordinate.
5959
name: Name of the Keras layer.
60-
**kwargs: Keyword arguments for every `tf.keras.Dense` layers.
60+
**kwargs: Keyword arguments for every `tf.keras.layers.Dense` layer.
6161
6262
Returns:
6363
A `tf.keras.Sequential` object.
@@ -67,13 +67,13 @@ def create_tower(hidden_layer_dims: List[int],
6767
if input_batch_norm:
6868
model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment))
6969
for layer_width in hidden_layer_dims:
70-
model.add(tf.keras.layers.Dense(units=layer_width), **kwargs)
70+
model.add(tf.keras.layers.Dense(units=layer_width, **kwargs))
7171
if use_batch_norm:
7272
model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment))
7373
model.add(tf.keras.layers.Activation(activation=activation))
7474
if dropout:
7575
model.add(tf.keras.layers.Dropout(rate=dropout))
76-
model.add(tf.keras.layers.Dense(units=output_units), **kwargs)
76+
model.add(tf.keras.layers.Dense(units=output_units, **kwargs))
7777
return model
7878

7979

tensorflow_ranking/python/keras/layers_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def test_create_tower(self):
2828
outputs = tower(inputs)
2929
self.assertAllEqual([2, 3, 1], outputs.get_shape().as_list())
3030

31+
def test_create_tower_with_bias_kwarg(self):
32+
tower = layers.create_tower([3, 2], 1, use_bias=False)
33+
tower_layers_bias = [tower.get_layer(name).use_bias for name in ['dense_1', 'dense_2']]
34+
self.assertAllEqual([False, False], tower_layers_bias)
3135

3236
class FlattenListTest(tf.test.TestCase):
3337

0 commit comments

Comments
 (0)