Skip to content

Commit 68aff3e

Browse files
authored
add exponential distribution to negloglik and losses; add causal loss metric (#13)
1 parent 081f390 commit 68aff3e

File tree

4 files changed

+245
-4
lines changed

4 files changed

+245
-4
lines changed

pypsps/keras/callbacks.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,19 @@ def on_epoch_end(self, epoch, logs=None):
2525
print(log_str)
2626

2727

28-
def recommended_callbacks(monitor="val_loss") -> List[tf.keras.callbacks.Callback]:
28+
def recommended_callbacks(
29+
monitor="val_loss", patience: int = 50, mode="min"
30+
) -> List[tf.keras.callbacks.Callback]:
2931
"""Return a list of recommended callbacks.
3032
3133
This list is subject to change w/o notice. Do not rely on this in production.
3234
"""
3335
callbacks = [
34-
tf.keras.callbacks.EarlyStopping(monitor=monitor, patience=20, restore_best_weights=True),
35-
tf.keras.callbacks.ReduceLROnPlateau(patience=10),
36+
tf.keras.callbacks.EarlyStopping(
37+
monitor=monitor, patience=patience, restore_best_weights=True, mode=mode
38+
),
39+
tf.keras.callbacks.ReduceLROnPlateau(patience=patience // 3),
3640
tf.keras.callbacks.TerminateOnNaN(),
37-
VerboseNEpochs(n=10),
41+
VerboseNEpochs(n=20),
3842
]
3943
return callbacks

pypsps/keras/metrics.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import tensorflow as tf
55

66
from .. import utils
7+
from . import losses
78

89

910
@tf.keras.utils.register_keras_serializable(package="pypsps")
@@ -96,3 +97,53 @@ def predictive_state_df(y_true, y_pred) -> tf.Tensor:
9697
return pypress.utils.tr_kernel(weights)
9798

9899
return predictive_state_df
100+
101+
102+
def causal_loss_metric_gen(
103+
outcome_loss: losses.OutcomeLoss,
104+
treatment_loss: losses.TreatmentLoss,
105+
alpha: float = 1.0,
106+
outcome_loss_weight: float = 1.0,
107+
):
108+
"""
109+
Function wrapper that returns a metric function computing the causal loss.
110+
111+
The returned function takes (y_true, y_pred) as inputs and computes:
112+
113+
causal_loss = outcome_loss_weight * outcome_loss(y_true, y_pred)
114+
+ alpha * treatment_loss(y_true, y_pred)
115+
116+
This metric function can be passed to model.compile(metrics=[...]).
117+
118+
Parameters
119+
----------
120+
outcome_loss : OutcomeLoss
121+
Instance of an outcome loss (e.g. Normal log-likelihood loss).
122+
treatment_loss : TreatmentLoss
123+
Instance of a treatment loss (e.g. binary cross-entropy for treatment prediction).
124+
alpha : float, default=1.0
125+
Penalty parameter for treatment loss.
126+
outcome_loss_weight : float, default=1.0
127+
Weight for the outcome loss.
128+
129+
Returns
130+
-------
131+
function
132+
A function metric that takes (y_true, y_pred) and returns the causal loss as a float value (can be passed as metric).
133+
"""
134+
# Construct an instance of CausalLoss with the given parameters.
135+
causal_loss_obj = losses.CausalLoss(
136+
outcome_loss=outcome_loss,
137+
treatment_loss=treatment_loss,
138+
alpha=alpha,
139+
outcome_loss_weight=outcome_loss_weight,
140+
)
141+
142+
def causal_loss_metric(y_true, y_pred) -> tf.Tensor:
143+
"""Metric function computing the causal loss."""
144+
# Call the causal loss object to compute the loss per example.
145+
# Here we assume causal_loss_obj returns per-example loss.
146+
return causal_loss_obj(y_true, y_pred)
147+
148+
causal_loss_metric.__name__ = "causal_loss"
149+
return causal_loss_metric

pypsps/keras/neglogliks.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,80 @@ def call(self, y_true, y_pred):
7171
):
7272
return tf.reduce_mean(losses, axis=-1)
7373
raise NotImplementedError("reduction='%s' is not implemented", self.reduction)
74+
75+
76+
def _negloglik_exponential(
77+
event_time: tf.Tensor, event_indicator: tf.Tensor, rate: tf.Tensor
78+
) -> tf.Tensor:
79+
"""
80+
Computes the negative log-likelihood for an exponential distribution with censoring.
81+
82+
For each observation i:
83+
- If an event occurs (event_indicator[i] == 1):
84+
log-likelihood = log(rate[i]) - rate[i] * event_time[i]
85+
- If censored (event_indicator[i] == 0):
86+
log-likelihood = - rate[i] * event_time[i]
87+
88+
Therefore, the negative log-likelihood for observation i is:
89+
loss_i = rate[i] * event_time[i] - event_indicator[i] * log(rate[i])
90+
91+
Parameters
92+
----------
93+
event_time : tf.Tensor, shape (n,)
94+
The observed event or censoring times.
95+
event_indicator : tf.Tensor, shape (n,)
96+
Binary indicator (1 if event occurred, 0 if censored).
97+
rate : tf.Tensor, shape (n,)
98+
The predicted rate (λ) of the exponential distribution.
99+
100+
Returns
101+
-------
102+
tf.Tensor
103+
A tensor of shape (n,) containing the negative log-likelihood for each observation.
104+
"""
105+
rate = tf.cast(rate, tf.float32)
106+
log_rate = tf.math.log(rate + _EPS)
107+
# Ensure inputs are float32
108+
event_time = tf.cast(event_time, tf.float32)
109+
event_indicator = tf.cast(event_indicator, tf.float32)
110+
111+
# Compute the negative log likelihood per observation
112+
nll = rate * event_time - event_indicator * log_rate
113+
return nll
114+
115+
116+
class NegloglikExponential(tf.keras.losses.Loss):
117+
"""Computes the negative log-likelihood of an Exponential survival model with censorship."""
118+
119+
def __init__(
120+
self,
121+
reduction=tf.keras.losses.Reduction.AUTO,
122+
log_rate: bool = False,
123+
name="negloglik_exponential",
124+
):
125+
super().__init__(reduction=reduction, name=name)
126+
self._log_rate = log_rate
127+
128+
def call(self, y_true, y_pred):
129+
"""Implements the loss function call."""
130+
event_time = y_true[:, 0]
131+
event_indicator = y_true[:, 1]
132+
133+
if self._log_rate:
134+
y_pred = tf.exp(y_pred)
135+
136+
# y_pred is the rate
137+
losses = _negloglik_exponential(
138+
tf.squeeze(event_time), tf.squeeze(event_indicator), rate=tf.squeeze(y_pred)
139+
)
140+
141+
if self.reduction == tf.keras.losses.Reduction.NONE:
142+
return losses
143+
if self.reduction == tf.keras.losses.Reduction.SUM:
144+
return tf.reduce_sum(losses, axis=-1)
145+
if self.reduction in (
146+
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
147+
tf.keras.losses.Reduction.AUTO,
148+
):
149+
return tf.reduce_mean(losses, axis=-1)
150+
raise NotImplementedError(f"reduction='{self.reduction}' is not implemented")

pypsps/tests/test_neglogliks.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,27 @@
44

55
import numpy as np
66
import pytest
7+
import tensorflow as tf
78
import tensorflow_probability as tfp
89

910
from ..keras import neglogliks
1011

1112
tfd = tfp.distributions
1213

1314

15+
def _create_sample_data_exponential():
16+
"""
17+
Creates a simple test case with two observations:
18+
- First observation: event_time=10, event_indicator=1, log_hazard=log(0.1)
19+
- Second observation: event_time=10, event_indicator=0, log_hazard=log(0.2)
20+
"""
21+
# y_true has shape (n, 2): columns are event_time and event_indicator.
22+
y_true = tf.constant([[10.0, 1.0], [10.0, 0.0]])
23+
# y_pred has shape (n, 1): log_hazard predictions.
24+
y_pred = tf.constant([[0.1], [0.2]])
25+
return y_true, y_pred
26+
27+
1428
def _test_data() -> Tuple[np.ndarray, np.ndarray]:
1529
y_true = np.array([0.0, 1.0, 2.0])
1630
y_pred = np.array([[0.0, 1.0], [-1, 0.1], [0.1, 0.5]])
@@ -46,3 +60,98 @@ def test_negloglik_loss_class_works(reduction):
4660
print(loss_normal)
4761
print(loss_class_normal)
4862
assert loss_normal.numpy() == pytest.approx(loss_class_normal.numpy(), 0.0001)
63+
64+
65+
# --------------------------------------------------------------------
66+
# Tests for _negloglik_exponential function.
67+
# --------------------------------------------------------------------
68+
69+
70+
def test_negloglik_exponential_event():
71+
"""
72+
Test when an event occurs (event_indicator == 1).
73+
74+
For an observation with:
75+
event_time = 10,
76+
event_indicator = 1,
77+
log_hazard = log(0.1) (so rate = 0.1),
78+
the loss should be: rate*event_time - log_hazard = 0.1*10 - log(0.1).
79+
"""
80+
event_time = tf.constant([10.0])
81+
event_indicator = tf.constant([1.0])
82+
rate = tf.constant([0.1])
83+
84+
loss = neglogliks._negloglik_exponential(event_time, event_indicator, rate)
85+
expected = 0.1 * 10 - np.log(0.1)
86+
np.testing.assert_allclose(loss.numpy(), [expected], atol=1e-5)
87+
88+
89+
def test_negloglik_exponential_censored():
90+
"""
91+
Test when an observation is censored (event_indicator == 0).
92+
93+
For an observation with:
94+
event_time = 10,
95+
event_indicator = 0,
96+
log_hazard = log(0.1) (so rate = 0.1),
97+
the loss should be: rate*event_time = 0.1*10.
98+
"""
99+
event_time = tf.constant([10.0])
100+
event_indicator = tf.constant([0.0])
101+
rate = tf.constant([0.1])
102+
103+
loss = neglogliks._negloglik_exponential(event_time, event_indicator, rate)
104+
expected = 0.1 * 10
105+
np.testing.assert_allclose(loss.numpy(), [expected], atol=1e-5)
106+
107+
108+
def test_NegloglikExponential_none():
109+
"""
110+
Test NegloglikExponential with reduction NONE.
111+
112+
Expected losses:
113+
Observation 1: 0.1*10 - log(0.1)
114+
Observation 2: 0.2*10
115+
"""
116+
loss_obj = neglogliks.NegloglikExponential(reduction=tf.keras.losses.Reduction.NONE)
117+
y_true, y_pred = _create_sample_data_exponential()
118+
losses = loss_obj(y_true, y_pred)
119+
120+
expected1 = 0.1 * 10 - np.log(0.1)
121+
expected2 = 0.2 * 10
122+
expected = np.array([expected1, expected2])
123+
np.testing.assert_allclose(losses.numpy(), expected, atol=1e-5)
124+
125+
126+
def test_NegloglikExponential_sum():
127+
"""
128+
Test NegloglikExponential with reduction SUM.
129+
130+
Expected loss: sum over observations.
131+
"""
132+
loss_obj = neglogliks.NegloglikExponential(reduction=tf.keras.losses.Reduction.SUM)
133+
y_true, y_pred = _create_sample_data_exponential()
134+
loss_value = loss_obj(y_true, y_pred)
135+
136+
expected1 = 0.1 * 10 - np.log(0.1)
137+
expected2 = 0.2 * 10
138+
expected = expected1 + expected2
139+
np.testing.assert_allclose(loss_value.numpy(), expected, atol=1e-5)
140+
141+
142+
def test_NegloglikExponential_sum_over_batch_size():
143+
"""
144+
Test NegloglikExponential with reduction SUM_OVER_BATCH_SIZE.
145+
146+
Expected loss: average loss over observations.
147+
"""
148+
loss_obj = neglogliks.NegloglikExponential(
149+
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
150+
)
151+
y_true, y_pred = _create_sample_data_exponential()
152+
loss_value = loss_obj(y_true, y_pred)
153+
154+
expected1 = 0.1 * 10 - np.log(0.1)
155+
expected2 = 0.2 * 10
156+
expected = (expected1 + expected2) / 2.0
157+
np.testing.assert_allclose(loss_value.numpy(), expected, atol=1e-5)

0 commit comments

Comments
 (0)