From 1400f70d3449ae1969ec7284fbe512dbddbaf865 Mon Sep 17 00:00:00 2001 From: brussell Date: Thu, 8 Dec 2022 08:28:22 -0800 Subject: [PATCH] Fix passing of keyword args to Dense layers in create_tower Current behavior: kwargs are passed to tf.keras.Sequential.add, so they are not passed on to tf.keras.layers.Dense as intended. For example, when passing a Dense keyword argument such as `use_bias=False` (or `kernel_regularizer`) to create_tower, it throws an exception: Traceback (most recent call last): File "/Users/brussell/development/ranking/tensorflow_ranking/python/keras/layers_test.py", line 33, in test_create_tower_with_kwargs tower = layers.create_tower([3, 2, 1], 1, activation='relu', use_bias=False) File "/Users/brussell/development/ranking/tensorflow_ranking/python/keras/layers.py", line 70, in create_tower model.add(tf.keras.layers.Dense(units=layer_width), **kwargs) File "/usr/local/anaconda3/lib/python3.9/site-packages/tensorflow/python/trackable/base.py", line 205, in _method_wrapper result = method(self, *args, **kwargs) File "/usr/local/anaconda3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 61, in error_handler return fn(*args, **kwargs) TypeError: add() got an unexpected keyword argument 'use_bias' test_create_tower_with_kwargs Fix: This PR fixes the behavior by moving the closing paren of the tf.keras.layers.Dense call to the correct location, so the kwargs are forwarded to the Dense constructor instead of Sequential.add. --- tensorflow_ranking/python/keras/layers.py | 6 +++--- tensorflow_ranking/python/keras/layers_test.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow_ranking/python/keras/layers.py b/tensorflow_ranking/python/keras/layers.py index 854c73d..8aa41dd 100644 --- a/tensorflow_ranking/python/keras/layers.py +++ b/tensorflow_ranking/python/keras/layers.py @@ -57,7 +57,7 @@ def create_tower(hidden_layer_dims: List[int], dropout: When not `None`, the probability we will drop out a given coordinate. name: Name of the Keras layer. 
- **kwargs: Keyword arguments for every `tf.keras.Dense` layers. + **kwargs: Keyword arguments for every `tf.keras.layers.Dense` layer. Returns: A `tf.keras.Sequential` object. @@ -67,13 +67,13 @@ def create_tower(hidden_layer_dims: List[int], if input_batch_norm: model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment)) for layer_width in hidden_layer_dims: - model.add(tf.keras.layers.Dense(units=layer_width), **kwargs) + model.add(tf.keras.layers.Dense(units=layer_width, **kwargs)) if use_batch_norm: model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment)) model.add(tf.keras.layers.Activation(activation=activation)) if dropout: model.add(tf.keras.layers.Dropout(rate=dropout)) - model.add(tf.keras.layers.Dense(units=output_units), **kwargs) + model.add(tf.keras.layers.Dense(units=output_units, **kwargs)) return model diff --git a/tensorflow_ranking/python/keras/layers_test.py b/tensorflow_ranking/python/keras/layers_test.py index d66ba14..9383f89 100644 --- a/tensorflow_ranking/python/keras/layers_test.py +++ b/tensorflow_ranking/python/keras/layers_test.py @@ -28,6 +28,10 @@ def test_create_tower(self): outputs = tower(inputs) self.assertAllEqual([2, 3, 1], outputs.get_shape().as_list()) + def test_create_tower_with_bias_kwarg(self): + tower = layers.create_tower([3, 2], 1, use_bias=False) + tower_layers_bias = [tower.get_layer(name).use_bias for name in ['dense_1', 'dense_2']] + self.assertAllEqual([False, False], tower_layers_bias) class FlattenListTest(tf.test.TestCase):