Skip to content

Commit af4c92f

Browse files
authored
chore: add example binary survival models (#14)
* add binary survival data and notebook example * remove surivlav data * update * remove comment * update tests
1 parent 68aff3e commit af4c92f

10 files changed

+2488
-34
lines changed

notebooks/pypsps_binary_survival_example.ipynb

+1,868
Large diffs are not rendered by default.

poetry.lock

+190-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ seaborn = "^0.13.2"
2727
scikit-learn = "^1.6.1"
2828
optuna = "^4.2.1"
2929
pydot = "^3.0.4"
30+
scikit-survival = "^0.24.0"
3031

3132
[build-system]
3233
requires = ["poetry-core>=1.0.0"]

pypsps/datasets/base.py

+7
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
true_ate: Optional[float] = None,
2525
true_ute: Optional[pd.DataFrame] = None,
2626
true_propensity_score: Optional[pd.DataFrame] = None,
27+
true_outcomes: Optional[pd.DataFrame] = None,
2728
):
2829
"""Initializes the class."""
2930
if isinstance(treatments, pd.Series):
@@ -40,12 +41,18 @@ def __init__(
4041
self.true_ate = true_ate
4142
self.true_ute = true_ute
4243
self.true_propensity_score = true_propensity_score
44+
self.true_outcomes = true_outcomes
4345

4446
def to_data_frame(self) -> pd.DataFrame:
    """Concatenates all parts of the dataset column-wise into a single DataFrame.

    Outcomes, treatments and features are always included (in that order).
    Latent features, true outcomes and the true propensity score follow, in
    that order, whenever they are not None.

    Returns:
      The result of `pd.concat(..., axis=1)` over the included parts.
    """
    optional_parts = (
        self.latent_features,
        self.true_outcomes,
        self.true_propensity_score,
    )
    frames = [self.outcomes, self.treatments, self.features]
    frames.extend(part for part in optional_parts if part is not None)
    return pd.concat(frames, axis=1)
5057

5158
def to_keras_inputs_outputs(

pypsps/datasets/binary_survival.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Toy example survival of a cancer treatment to survival times."""
2+
3+
import math
4+
5+
import numpy as np
6+
import pandas as pd
7+
from scipy.special import expit # logistic sigmoid
8+
9+
from . import base
10+
11+
_FEAT_COLS = ["gender", "age", "comorbidity", "cancer_severity"]
12+
13+
14+
def _simple_custom_uuid(val):
15+
"""
16+
Returns a simple custom UUID string for the given value.
17+
This function uses the built-in hash() and converts the absolute hash value
18+
to a base-36 string (digits and lowercase letters).
19+
"""
20+
# Get absolute hash value to avoid negative numbers.
21+
h = abs(hash(val))
22+
alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
23+
if h == 0:
24+
return alphabet[0]
25+
s = []
26+
while h:
27+
s.append(alphabet[h % 36])
28+
h //= 36
29+
return "".join(reversed(s))
30+
31+
32+
class CancerSurvivalSimulator(base.BaseSimulator):
    """Toy cancer study with a binary chemotherapy treatment and censored recovery times.

    Recovery times are drawn from an exponential distribution whose scale depends
    on treatment, and are censored at a fixed follow-up horizon.
    """

    # Median recovery time (days) without / with chemotherapy.
    _MEDIAN_DAYS_UNTREATED = 365.0
    _MEDIAN_DAYS_TREATED = 100.0
    # Administrative censoring: follow-up ends after this many days.
    _FOLLOW_UP_DAYS = 540

    def __init__(self, **kwargs):
        """Initializes the simulator; all keyword arguments go to the base class."""
        super().__init__(**kwargs)
        # NOTE(review): `self._seed` is expected to be set by `base.BaseSimulator.__init__`.
        self._rng = np.random.RandomState(self._seed)

    def sample(self, n_samples: int):
        """Samples an example dataset.

        Args:
          n_samples: number of patients to simulate.

        Returns:
          A `base.CausalDataset` with
            * treatments: binary `chemotherapy` indicator,
            * outcomes: `event_time` (censored at follow-up) and `event_indicator`,
            * features: the columns in `_FEAT_COLS` (gender / comorbidity integer-encoded),
            * true_propensity_score: the treatment probability used for sampling,
            * true_outcomes: the uncensored `true_recovery_time`,
            * true_ute / true_ate: reduction in mean recovery time due to treatment.
          Rows are sorted by decreasing treatment probability and indexed by `patient_id`.
        """
        # 1. Generate features.
        # Gender: "male" / "female" with equal probability.
        genders = self._rng.choice(["male", "female"], size=n_samples)

        # Age: uniformly distributed from 30 to 80.
        ages = self._rng.uniform(30, 80, size=n_samples)

        # Comorbidity: categorical; probabilities: low (50%), medium (30%), high (20%).
        comorbidity = self._rng.choice(["low", "medium", "high"], size=n_samples, p=[0.5, 0.3, 0.2])

        # Cancer severity: uniformly from 1 to 10.
        cancer_severity = self._rng.uniform(1, 10, size=n_samples)

        # 2. Treatment assignment (chemotherapy).
        # Logistic model: treatment is more likely if cancer_severity is high and age is low.
        lp = 1.0 + 0.5 * cancer_severity - 0.05 * ages
        p_chemo = expit(lp)
        chemo = self._rng.binomial(1, p_chemo, size=n_samples)

        # 3. Simulate true recovery time from an exponential distribution.
        # The sampler's "scale" parameter is 1 / lambda = mean, and
        # median = scale * ln(2)  =>  scale = median / ln(2).
        scale_untreated = self._MEDIAN_DAYS_UNTREATED / math.log(2)  # ~526.6 days
        scale_treated = self._MEDIAN_DAYS_TREATED / math.log(2)  # ~144.3 days

        # For each patient, choose the scale based on treatment.
        scales = np.where(chemo == 1, scale_treated, scale_untreated)
        true_recovery_time = self._rng.exponential(scale=scales)

        # 4. Impose study follow-up: observed time is min(true time, follow-up horizon).
        observed_time = np.minimum(true_recovery_time, self._FOLLOW_UP_DAYS)
        # Event indicator: 1 if recovery happened within follow-up, else 0 (censored).
        event_indicator = (true_recovery_time <= self._FOLLOW_UP_DAYS).astype(int)

        # 5. Assemble DataFrame.
        df = pd.DataFrame(
            {
                "gender": genders,
                "age": ages,
                "comorbidity": comorbidity,
                "cancer_severity": cancer_severity,
                "chemotherapy": chemo,
                "true_recovery_time": true_recovery_time,
                "event_time": observed_time,
                "event_indicator": event_indicator,
                "prob_chemotherapy": p_chemo,
            }
        )
        df = df.sort_values("prob_chemotherapy", ascending=False)
        # Patient ids are derived from the pre-sort row number (offset by 1e6).
        df["patient_id"] = (df.index.to_series() + 1e6).apply(_simple_custom_uuid)
        df = df.set_index("patient_id", verify_integrity=True)

        df["gender"] = df["gender"].map({"male": 0, "female": 1})
        df["comorbidity"] = df["comorbidity"].map({"low": 0, "medium": 1, "high": 2})

        # True effect = reduction in mean (uncensored) recovery time: untreated minus treated.
        # Derived from the scales actually used above so it cannot drift from the simulation
        # parameters (a hard-coded (365 / 2) / ln(2) only holds for a treated median of
        # 365 / 2 days).
        # NOTE(review): the sign convention (positive = faster recovery under treatment) follows
        # the previously hard-coded positive value -- confirm against downstream consumers.
        true_ute = pd.Series(scale_untreated - scale_treated, index=df.index, name="true_ute")

        return base.CausalDataset(
            treatments=df["chemotherapy"],
            outcomes=df[["event_time", "event_indicator"]],
            features=df[_FEAT_COLS],
            true_ate=true_ute.mean(),
            true_ute=true_ute,
            true_propensity_score=df["prob_chemotherapy"],
            true_outcomes=df["true_recovery_time"],
        )

pypsps/keras/callbacks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, n: int = 10):
1717
def on_epoch_end(self, epoch, logs=None):
1818
"""call at end of epoch"""
1919
# logs is a dictionary containing metric names and values.
20-
if (epoch + 1) % self.n == 0:
20+
if (epoch == 0) or ((epoch + 1) % self.n == 0):
2121
logs = logs or {}
2222
log_str = f"Epoch {epoch + 1}: " + ", ".join(
2323
f"{key}={value:.4f}" for key, value in logs.items()

pypsps/keras/metrics.py

+95-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Module for metrics from pypsps predictions."""
22

3-
import pypress
3+
import pypress.utils
44
import tensorflow as tf
55

66
from .. import utils
@@ -11,10 +11,23 @@
1111
class PropensityScoreBinaryCrossentropy(tf.keras.metrics.BinaryCrossentropy):
1212
"""Computes cross entropy for the propensity score. Used as a metric in pypsps model."""
1313

14+
def __init__(
    self,
    n_outcome_pred_cols: int,
    n_treatment_pred_cols: int,
    name="propensity_score_binary_crossentropy",
    **kwargs,
):
    """Initializes the metric.

    Args:
      n_outcome_pred_cols: passed to `utils.split_y_pred` when splitting `y_pred`.
      n_treatment_pred_cols: passed to `utils.split_y_pred` when splitting `y_pred`.
      name: name of the metric.
      **kwargs: forwarded to the base class constructor.
    """
    # Forward the remaining keyword arguments instead of silently dropping them.
    super().__init__(name=name, **kwargs)
    self._n_outcome_pred_cols = n_outcome_pred_cols
    self._n_treatment_pred_cols = n_treatment_pred_cols
24+
1425
def update_state(self, y_true, y_pred, sample_weight=None):
1526
"""Updates state."""
1627
_, _, propensity_score = utils.split_y_pred(
17-
y_pred, n_outcome_pred_cols=2, n_treatment_pred_cols=1
28+
y_pred,
29+
n_outcome_pred_cols=self._n_outcome_pred_cols,
30+
n_treatment_pred_cols=self._n_treatment_pred_cols,
1831
)
1932
treatment_true = y_true[:, 1:]
2033
super().update_state(
@@ -26,10 +39,17 @@ def update_state(self, y_true, y_pred, sample_weight=None):
2639
class PropensityScoreAUC(tf.keras.metrics.AUC):
2740
"""AUC computed on the ouptut for propensity part."""
2841

42+
def __init__(self, n_outcome_pred_cols: int, n_treatment_pred_cols: int, **kwargs):
    """Initializes the metric.

    Args:
      n_outcome_pred_cols: passed to `utils.split_y_pred` when splitting `y_pred`.
      n_treatment_pred_cols: passed to `utils.split_y_pred` when splitting `y_pred`.
      **kwargs: forwarded to the base class constructor (e.g. `name`).
    """
    # Forward keyword arguments instead of silently dropping them (a caller-supplied
    # `name` would otherwise be ignored).
    super().__init__(**kwargs)
    self._n_outcome_pred_cols = n_outcome_pred_cols
    self._n_treatment_pred_cols = n_treatment_pred_cols
46+
2947
def update_state(self, y_true, y_pred, sample_weight=None):
3048
"""Updates state"""
3149
_, _, propensity_score = utils.split_y_pred(
32-
y_pred, n_outcome_pred_cols=2, n_treatment_pred_cols=1
50+
y_pred,
51+
n_outcome_pred_cols=self._n_outcome_pred_cols,
52+
n_treatment_pred_cols=self._n_treatment_pred_cols,
3353
)
3454
treatment_true = y_true[:, 1:]
3555
super().update_state(
@@ -41,10 +61,26 @@ def update_state(self, y_true, y_pred, sample_weight=None):
4161
class TreatmentMeanSquaredError(tf.keras.metrics.MeanSquaredError):
    """MSE computed on continuous treatment prediction."""

    def __init__(
        self,
        n_outcome_pred_cols: int,
        n_treatment_pred_cols: int,
        n_outcome_true_cols: int,
        **kwargs,
    ):
        """Initializes the metric.

        Args:
          n_outcome_pred_cols: passed to `utils.split_y_pred` when splitting `y_pred`.
          n_treatment_pred_cols: passed to `utils.split_y_pred` when splitting `y_pred`.
          n_outcome_true_cols: passed to `utils.split_y_true` when splitting `y_true`.
          **kwargs: forwarded to the base class constructor (e.g. `name`, `dtype`).
        """
        # Forward keyword arguments instead of silently dropping them.
        super().__init__(**kwargs)
        self._n_outcome_true_cols = n_outcome_true_cols
        self._n_outcome_pred_cols = n_outcome_pred_cols
        self._n_treatment_pred_cols = n_treatment_pred_cols

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Updates state with the treatment parts of `y_true` and `y_pred`."""
        # Third element of `split_y_pred` is used as the treatment prediction.
        treat_pred = utils.split_y_pred(
            y_pred,
            n_outcome_pred_cols=self._n_outcome_pred_cols,
            n_treatment_pred_cols=self._n_treatment_pred_cols,
        )[2]
        # Second element of `split_y_true` is used as the observed treatment.
        treat_true = utils.split_y_true(y_true, n_outcome_true_cols=self._n_outcome_true_cols)[1]
        super().update_state(y_true=treat_true, y_pred=treat_pred, sample_weight=sample_weight)

    def get_config(self):
        """Returns the config, including the required column counts.

        Without these entries the metric cannot be re-created from its config,
        since the constructor arguments are required.
        """
        config = super().get_config()
        config.update(
            {
                "n_outcome_pred_cols": self._n_outcome_pred_cols,
                "n_treatment_pred_cols": self._n_treatment_pred_cols,
                "n_outcome_true_cols": self._n_outcome_true_cols,
            }
        )
        return config
4985

5086

@@ -53,21 +89,53 @@ def update_state(self, y_true, y_pred, sample_weight=None):
5389
class TreatmentMeanAbsoluteError(tf.keras.metrics.MeanAbsoluteError):
    """MAE computed on continuous treatment prediction."""

    def __init__(
        self,
        n_outcome_pred_cols: int,
        n_treatment_pred_cols: int,
        n_outcome_true_cols: int,
        **kwargs,
    ):
        """Initializes the metric.

        Args:
          n_outcome_pred_cols: passed to `utils.split_y_pred` when splitting `y_pred`.
          n_treatment_pred_cols: passed to `utils.split_y_pred` when splitting `y_pred`.
          n_outcome_true_cols: passed to `utils.split_y_true` when splitting `y_true`.
          **kwargs: forwarded to the base class constructor (e.g. `name`, `dtype`).
        """
        # Forward keyword arguments instead of silently dropping them.
        super().__init__(**kwargs)
        self._n_outcome_true_cols = n_outcome_true_cols
        self._n_outcome_pred_cols = n_outcome_pred_cols
        self._n_treatment_pred_cols = n_treatment_pred_cols

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Updates state with the treatment parts of `y_true` and `y_pred`."""
        # Third element of `split_y_pred` is used as the treatment prediction.
        treat_pred = utils.split_y_pred(
            y_pred,
            n_outcome_pred_cols=self._n_outcome_pred_cols,
            n_treatment_pred_cols=self._n_treatment_pred_cols,
        )[2]
        # `y_true` must be split by the number of *true outcome* columns; using the number of
        # treatment prediction columns here would mis-split whenever the two counts differ.
        treat_true = utils.split_y_true(y_true, n_outcome_true_cols=self._n_outcome_true_cols)[1]
        super().update_state(y_true=treat_true, y_pred=treat_pred, sample_weight=sample_weight)

    def get_config(self):
        """Returns the config, including the required column counts.

        Without these entries the metric cannot be re-created from its config,
        since the constructor arguments are required.
        """
        config = super().get_config()
        config.update(
            {
                "n_outcome_pred_cols": self._n_outcome_pred_cols,
                "n_treatment_pred_cols": self._n_treatment_pred_cols,
                "n_outcome_true_cols": self._n_outcome_true_cols,
            }
        )
        return config
61113

62114

63115
@tf.keras.utils.register_keras_serializable(package="pypsps")
class OutcomeMeanSquaredError(tf.keras.metrics.MeanSquaredError):
    """MSE computed on the output for weighted average outcome prediction."""

    def __init__(
        self,
        n_outcome_pred_cols: int,
        n_treatment_pred_cols: int,
        n_outcome_true_cols: int,
        **kwargs,
    ):
        """Initializes the metric.

        Args:
          n_outcome_pred_cols: passed to `utils.agg_outcome_pred` for `y_pred`.
          n_treatment_pred_cols: passed to `utils.agg_outcome_pred` for `y_pred`.
          n_outcome_true_cols: passed to `utils.split_y_true` when splitting `y_true`.
          **kwargs: forwarded to the base class constructor (e.g. `name`, `dtype`).
        """
        # Forward keyword arguments instead of silently dropping them.
        super().__init__(**kwargs)
        self._n_outcome_true_cols = n_outcome_true_cols
        self._n_outcome_pred_cols = n_outcome_pred_cols
        self._n_treatment_pred_cols = n_treatment_pred_cols

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Updates state with the true outcome and the aggregated outcome prediction."""
        avg_outcome = utils.agg_outcome_pred(
            y_pred,
            n_outcome_pred_cols=self._n_outcome_pred_cols,
            n_treatment_pred_cols=self._n_treatment_pred_cols,
        )
        # First element of `split_y_true` is used as the observed outcome.
        outcome_true = utils.split_y_true(y_true, n_outcome_true_cols=self._n_outcome_true_cols)[0]
        super().update_state(y_true=outcome_true, y_pred=avg_outcome, sample_weight=sample_weight)

    def get_config(self):
        """Returns the config, including the required column counts.

        The class is registered as Keras-serializable; without these entries it
        cannot be re-created from its config, since the constructor arguments
        are required.
        """
        config = super().get_config()
        config.update(
            {
                "n_outcome_pred_cols": self._n_outcome_pred_cols,
                "n_treatment_pred_cols": self._n_treatment_pred_cols,
                "n_outcome_true_cols": self._n_outcome_true_cols,
            }
        )
        return config
72140

73141

@@ -76,10 +144,26 @@ def update_state(self, y_true, y_pred, sample_weight=None):
76144
class OutcomeMeanAbsoluteError(tf.keras.metrics.MeanAbsoluteError):
    """MAE computed on the output for weighted average outcome prediction."""

    def __init__(
        self,
        n_outcome_pred_cols: int,
        n_treatment_pred_cols: int,
        n_outcome_true_cols: int,
        **kwargs,
    ):
        """Initializes the metric.

        Args:
          n_outcome_pred_cols: passed to `utils.agg_outcome_pred` for `y_pred`.
          n_treatment_pred_cols: passed to `utils.agg_outcome_pred` for `y_pred`.
          n_outcome_true_cols: passed to `utils.split_y_true` when splitting `y_true`.
          **kwargs: forwarded to the base class constructor (e.g. `name`, `dtype`).
        """
        # Forward keyword arguments instead of silently dropping them.
        super().__init__(**kwargs)
        self._n_outcome_true_cols = n_outcome_true_cols
        self._n_outcome_pred_cols = n_outcome_pred_cols
        self._n_treatment_pred_cols = n_treatment_pred_cols

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Updates state with the true outcome and the aggregated outcome prediction."""
        avg_outcome = utils.agg_outcome_pred(
            y_pred,
            n_outcome_pred_cols=self._n_outcome_pred_cols,
            n_treatment_pred_cols=self._n_treatment_pred_cols,
        )
        # First element of `split_y_true` is used as the observed outcome.
        outcome_true = utils.split_y_true(y_true, n_outcome_true_cols=self._n_outcome_true_cols)[0]
        super().update_state(y_true=outcome_true, y_pred=avg_outcome, sample_weight=sample_weight)

    def get_config(self):
        """Returns the config, including the required column counts.

        Without these entries the metric cannot be re-created from its config,
        since the constructor arguments are required.
        """
        config = super().get_config()
        config.update(
            {
                "n_outcome_pred_cols": self._n_outcome_pred_cols,
                "n_treatment_pred_cols": self._n_treatment_pred_cols,
                "n_outcome_true_cols": self._n_outcome_true_cols,
            }
        )
        return config
84168

85169

0 commit comments

Comments
 (0)