
Commit e425aef

add a build_model_binary_normal model function with more arguments to tweak architecture (#7)
1 parent bfba140 commit e425aef

File tree: 3 files changed (+198, -7 lines changed)

pypsps/keras/models.py (+154, -5)
@@ -1,6 +1,6 @@
 """Example model architectures for pypsps."""

-from typing import List
+from typing import List, Tuple

 import tensorflow as tf

@@ -31,7 +31,9 @@ def recommended_callbacks(monitor="val_loss") -> List[tf.keras.callbacks.Callbac


 def _build_binary_continuous_causal_loss(
-    n_states: int, alpha: float = 1.0
+    n_states: int,
+    alpha: float,
+    df_penalty_l1: float,
 ) -> losses.CausalLoss:
     """Builds an example of binary treatment & continuous outcome causal loss."""
     psps_outcome_loss = losses.OutcomeLoss(
@@ -48,7 +50,7 @@ def _build_binary_continuous_causal_loss(
         alpha=alpha,
         outcome_loss_weight=1.0,
         predictive_states_regularizer=pypress.keras.regularizers.DegreesOfFreedom(
-            10.0, df=n_states - 1
+            l1=df_penalty_l1, df=n_states - 1
        ),
        reduction="sum_over_batch_size",
    )
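
The L1 strength of the DegreesOfFreedom regularizer was hard-coded to 10.0 before this change; it is now exposed as df_penalty_l1 (and is a required argument of the private builder, not a defaulted one). A minimal sketch of tuning it, with illustrative values:

    from pypsps.keras import models

    # df_penalty_l1 replaces the old hard-coded 10.0 L1 weight (value illustrative).
    psps_loss = models._build_binary_continuous_causal_loss(
        n_states=3,
        alpha=1.0,
        df_penalty_l1=1.0,
    )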
@@ -60,6 +62,8 @@ def build_toy_model(
     n_features: int,
     compile: bool = True,
     alpha: float = 1.0,
+    df_penalty_l1: float = 1.0,
+    learning_rate: float = 0.01,
 ) -> tf.keras.Model:
     """Builds a pypsps toy model for binary treatment & continous outcome.

@@ -72,6 +76,8 @@ def build_toy_model(
         n_features: number of (numeric) features to use as input.
         compile: if True, compiles pypsps model with the appropriate pypsps causal loss functions.
         alpha: propensity score penalty (by default alpha = 1., which corresponds to equal weight)
+        df_penalty_l1: l1 parameter for the DF regularization
+        learning_rate: learning rate of the optimizer.

     Returns:
         A tf.keras Model with the pypsps architecture (compiled model if `compile=True`).
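
Both new arguments are plumbed through build_toy_model; a sketch of a call that uses them (values illustrative, defaults as in the signature above):

    from pypsps.keras import models

    toy_model = models.build_toy_model(
        n_states=3,
        n_features=10,
        compile=True,
        alpha=1.0,
        df_penalty_l1=5.0,    # forwarded to the DegreesOfFreedom regularizer
        learning_rate=0.005,  # forwarded to the Nadam optimizer
    )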
@@ -141,11 +147,154 @@ def build_toy_model(
     if compile:

         psps_causal_loss = _build_binary_continuous_causal_loss(
-            n_states=n_states, alpha=alpha
+            n_states=n_states,
+            alpha=alpha,
+            df_penalty_l1=df_penalty_l1,
+        )
+        model.compile(
+            loss=psps_causal_loss,
+            optimizer=tfk.optimizers.Nadam(learning_rate=learning_rate),
+            metrics=[
+                metrics.PropensityScoreBinaryCrossentropy(),
+                metrics.PropensityScoreAUC(curve="PR"),
+                metrics.OutcomeMeanSquaredError(),
+            ],
+        )
+
+    return model
+
+
+def build_model_binary_normal(
+    n_states: int,
+    n_features: int,
+    predictive_state_hidden_layers: List[Tuple[int, str]],
+    outcome_hidden_layers: List[Tuple[int, str]],
+    loc_layer: Tuple[int, str] = None,
+    scale_layer: Tuple[int, str] = None,
+    compile: bool = True,
+    alpha: float = 1.0,
+    df_penalty_l1: float = 1.0,
+    learning_rate: float = 0.01,
+    dropout_rate: float = 0.2,
+) -> tf.keras.Model:
+    """Builds a pypsps model for binary treatment & continuous (normal) outcome.
+
+    All pypsps keras layers can be used to build more complex causal model architectures
+    within a TensorFlow graph. The specific model structure here is only used
+    for proof-of-concept / demo purposes.
+
+    Args:
+        n_states: number of predictive states to use in the pypsps model.
+        n_features: number of (numeric) features to use as input.
+        predictive_state_hidden_layers: (units, activation) specs for the hidden
+          layers of the predictive-state branch.
+        outcome_hidden_layers: (units, activation) specs for the hidden layers
+          of the outcome branch.
+        loc_layer: (units, activation) spec for the per-state loc feature layer.
+        scale_layer: (units, activation) spec for the per-state scale feature
+          layer; if None, a constant (BiasOnly) scale is used per state.
+        compile: if True, compiles pypsps model with the appropriate pypsps causal loss functions.
+        alpha: propensity score penalty (by default alpha = 1., which corresponds to equal weight)
+        df_penalty_l1: l1 parameter for the DF regularization
+        learning_rate: learning rate of the optimizer.
+        dropout_rate: dropout rate applied after each hidden Dense layer.
+
+    Returns:
+        A tf.keras Model with the pypsps architecture (compiled model if `compile=True`).
+    """
+
+    assert n_states >= 1, f"Got n_states={n_states}"
+    assert n_features >= 1, f"Got n_features={n_features}"
+
+    features = tfk.layers.Input(shape=(n_features,))
+    treat = tfk.layers.Input(shape=(1,))
+
+    features_bn = tfk.layers.BatchNormalization()(features)
+    feat_treat = tfk.layers.Concatenate(name="features_and_treatment")(
+        [features_bn, treat]
+    )
+
+    ps_hidden = tf.keras.layers.Dense(
+        predictive_state_hidden_layers[0][0], predictive_state_hidden_layers[0][1]
+    )(features_bn)
+    ps_hidden = tf.keras.layers.Dropout(dropout_rate)(ps_hidden)
+    ps_hidden = tf.keras.layers.BatchNormalization()(ps_hidden)
+
+    for units, act in predictive_state_hidden_layers[1:]:
+        ps_hidden = tf.keras.layers.Dense(units, act)(ps_hidden)
+        ps_hidden = tf.keras.layers.Dropout(dropout_rate)(ps_hidden)
+        ps_hidden = tf.keras.layers.BatchNormalization()(ps_hidden)
+
+    ps_hidden = tf.keras.layers.Concatenate()([ps_hidden, features_bn])
+    pss = pypress.keras.layers.PredictiveStateSimplex(
+        n_states=n_states, input_dim=n_features
+    )
+    pred_states = pss(ps_hidden)
+
+    # Propensity score for binary treatment (--> "sigmoid" activation).
+    prop_score = pypress.keras.layers.PredictiveStateMeans(
+        units=1, activation="sigmoid", name="propensity_score"
+    )(pred_states)
+
+    outcome_hidden = tf.keras.layers.Dense(
+        outcome_hidden_layers[0][0], outcome_hidden_layers[0][1]
+    )(feat_treat)
+    outcome_hidden = tf.keras.layers.Dropout(dropout_rate)(outcome_hidden)
+    outcome_hidden = tf.keras.layers.BatchNormalization()(outcome_hidden)
+
+    for units, act in outcome_hidden_layers[1:]:
+        outcome_hidden = tf.keras.layers.Dense(units, act)(outcome_hidden)
+        outcome_hidden = tf.keras.layers.Dropout(dropout_rate)(outcome_hidden)
+        outcome_hidden = tf.keras.layers.BatchNormalization()(outcome_hidden)
+
+    outcome_hidden = tf.keras.layers.Concatenate()([outcome_hidden, feat_treat])
+
+    loc_preds = []
+    scale_preds = []
+    # One outcome model per state.
+    for state_id in range(n_states):
+        loc_preds.append(
+            tfk.layers.Dense(1, name="loc_pred_state_" + str(state_id))(
+                tfk.layers.Dense(
+                    loc_layer[0],
+                    loc_layer[1],
+                    name="loc_feat_eng_state_" + str(state_id),
+                )(outcome_hidden)
+            )
+        )
+
+        if scale_layer is None:
+            # In this toy model use a constant scale estimate (BiasOnly); if needed
+            # change this to a scale parameter that changes as a function of inputs / hidden layers.
+            scale_preds.append(
+                tf.keras.activations.softplus(
+                    layers.BiasOnly(name="scale_logit_" + str(state_id))(feat_treat)
+                )
+            )
+        else:
+            scale_preds.append(
+                tfk.layers.Dense(
+                    1, activation="softplus", name="scale_pred_state_" + str(state_id)
+                )(
+                    tfk.layers.Dense(
+                        scale_layer[0],
+                        scale_layer[1],
+                        name="scale_feat_eng_state_" + str(state_id),
+                    )(outcome_hidden)
+                )
+            )
+
+    loc_comb = tfk.layers.Concatenate(name="loc_pred_combined")(loc_preds)
+    scale_comb = tfk.layers.Concatenate(name="scale_pred_combined")(scale_preds)
+
+    outputs_concat = tfk.layers.Concatenate(name="output_tensor")(
+        [loc_comb, scale_comb, pred_states, prop_score]
+    )
+
+    model = tfk.models.Model(inputs=[features, treat], outputs=outputs_concat)
+
+    if compile:
+
+        psps_causal_loss = _build_binary_continuous_causal_loss(
+            n_states=n_states,
+            alpha=alpha,
+            df_penalty_l1=df_penalty_l1,
         )
         model.compile(
             loss=psps_causal_loss,
-            optimizer=tfk.optimizers.Nadam(learning_rate=0.01),
+            optimizer=tfk.optimizers.Nadam(learning_rate=learning_rate),
             metrics=[
                 metrics.PropensityScoreBinaryCrossentropy(),
                 metrics.PropensityScoreAUC(curve="PR"),
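
For reference, a minimal sketch of calling the new builder (mirroring pypsps/tests/test_models.py below; layer sizes and activations are illustrative). Note that loc_layer defaults to None but is indexed unconditionally inside the per-state loop, so in practice it must be passed explicitly:

    from pypsps.keras import models

    model = models.build_model_binary_normal(
        n_states=3,
        n_features=10,
        predictive_state_hidden_layers=[(10, "selu"), (20, "relu")],
        outcome_hidden_layers=[(30, "tanh"), (20, "selu")],
        loc_layer=(20, "selu"),    # required in practice (see note above)
        scale_layer=(10, "tanh"),  # None -> constant BiasOnly scale per state
        compile=True,
    )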

pypsps/tests/test_losses.py (-2)

@@ -1,6 +1,5 @@
 """Test module for loss functions."""

-
 import numpy as np
 import pytest
 import tensorflow as tf
@@ -81,6 +80,5 @@ def test_end_to_end_dataset_model_fit():

     assert preds.shape[0] == ks_data.n_samples

-    outcome_pred, scale_pred, weights, prop_score = utils.split_y_pred(preds)
     ate = inference.predict_ate(model, inputs[0])
     assert ate > 0

pypsps/tests/test_models.py (+44)

@@ -0,0 +1,44 @@
+"""Test module for model functions."""
+
+import numpy as np
+import pytest
+import tensorflow as tf
+import random
+
+from .. import datasets
+from ..keras import models
+
+
+tfk = tf.keras
+
+
+def test_build_toy_model():
+    np.random.seed(10)
+    ks_data = datasets.KangSchafer(true_ate=10).sample(n_samples=1000)
+
+    inputs, outputs = ks_data.to_keras_inputs_outputs()
+    tf.random.set_seed(10)
+    model = models.build_toy_model(
+        n_states=3, n_features=ks_data.n_features, compile=True
+    )
+    preds = model.predict(inputs)
+    assert not np.isnan(preds.sum().sum())
+
+
+def test_build_model():
+    np.random.seed(10)
+    ks_data = datasets.KangSchafer(true_ate=10).sample(n_samples=1000)
+
+    inputs, outputs = ks_data.to_keras_inputs_outputs()
+    tf.random.set_seed(10)
+    model = models.build_model_binary_normal(
+        n_states=3,
+        n_features=ks_data.n_features,
+        compile=True,
+        predictive_state_hidden_layers=[(10, "selu"), (20, "relu")],
+        outcome_hidden_layers=[(30, "tanh"), (20, "selu")],
+        loc_layer=(20, "selu"),
+        scale_layer=(10, "tanh"),
+    )
+    preds = model.predict(inputs)
+    assert not np.isnan(preds.sum().sum())
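
Per the output_tensor concatenation in models.py above, predictions stack n_states loc columns, n_states scale columns, n_states predictive-state weights, and one propensity-score column (10 columns for n_states=3). A sketch of slicing them manually, assuming that column order (utils.split_y_pred, seen in test_losses.py, wraps the equivalent logic):

    n_states = 3
    preds = model.predict(inputs)  # shape (n_samples, 3 * n_states + 1)

    loc_pred = preds[:, :n_states]                         # per-state outcome means
    scale_pred = preds[:, n_states : 2 * n_states]         # per-state outcome scales
    state_weights = preds[:, 2 * n_states : 3 * n_states]  # predictive-state weights
    prop_score = preds[:, 3 * n_states :]                  # propensity score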
