1
1
"""Module for metrics from pypsps predictions."""
2
2
3
- import pypress
3
+ import pypress . utils
4
4
import tensorflow as tf
5
5
6
6
from .. import utils
11
11
class PropensityScoreBinaryCrossentropy (tf .keras .metrics .BinaryCrossentropy ):
12
12
"""Computes cross entropy for the propensity score. Used as a metric in pypsps model."""
13
13
14
+ def __init__ (
15
+ self ,
16
+ n_outcome_pred_cols : int ,
17
+ n_treatment_pred_cols : int ,
18
+ name = "propensity_score_binary_crossentropy" ,
19
+ ** kwargs ,
20
+ ):
21
+ super ().__init__ (name = name )
22
+ self ._n_outcome_pred_cols = n_outcome_pred_cols
23
+ self ._n_treatment_pred_cols = n_treatment_pred_cols
24
+
14
25
def update_state (self , y_true , y_pred , sample_weight = None ):
15
26
"""Updates state."""
16
27
_ , _ , propensity_score = utils .split_y_pred (
17
- y_pred , n_outcome_pred_cols = 2 , n_treatment_pred_cols = 1
28
+ y_pred ,
29
+ n_outcome_pred_cols = self ._n_outcome_pred_cols ,
30
+ n_treatment_pred_cols = self ._n_treatment_pred_cols ,
18
31
)
19
32
treatment_true = y_true [:, 1 :]
20
33
super ().update_state (
@@ -26,10 +39,17 @@ def update_state(self, y_true, y_pred, sample_weight=None):
26
39
class PropensityScoreAUC (tf .keras .metrics .AUC ):
27
40
"""AUC computed on the ouptut for propensity part."""
28
41
42
+ def __init__ (self , n_outcome_pred_cols : int , n_treatment_pred_cols : int , ** kwargs ):
43
+ super ().__init__ ()
44
+ self ._n_outcome_pred_cols = n_outcome_pred_cols
45
+ self ._n_treatment_pred_cols = n_treatment_pred_cols
46
+
29
47
def update_state (self , y_true , y_pred , sample_weight = None ):
30
48
"""Updates state"""
31
49
_ , _ , propensity_score = utils .split_y_pred (
32
- y_pred , n_outcome_pred_cols = 2 , n_treatment_pred_cols = 1
50
+ y_pred ,
51
+ n_outcome_pred_cols = self ._n_outcome_pred_cols ,
52
+ n_treatment_pred_cols = self ._n_treatment_pred_cols ,
33
53
)
34
54
treatment_true = y_true [:, 1 :]
35
55
super ().update_state (
@@ -41,10 +61,26 @@ def update_state(self, y_true, y_pred, sample_weight=None):
41
61
class TreatmentMeanSquaredError (tf .keras .metrics .MeanSquaredError ):
42
62
"""MSE computed on continuous treatment prediction."""
43
63
64
+ def __init__ (
65
+ self ,
66
+ n_outcome_pred_cols : int ,
67
+ n_treatment_pred_cols : int ,
68
+ n_outcome_true_cols : int ,
69
+ ** kwargs ,
70
+ ):
71
+ super ().__init__ ()
72
+ self ._n_outcome_true_cols = n_outcome_true_cols
73
+ self ._n_outcome_pred_cols = n_outcome_pred_cols
74
+ self ._n_treatment_pred_cols = n_treatment_pred_cols
75
+
44
76
def update_state (self , y_true , y_pred , sample_weight = None ):
45
77
"""Updates state"""
46
- treat_pred = utils .split_y_pred (y_pred , n_outcome_pred_cols = 1 , n_treatment_pred_cols = 2 )[2 ]
47
- treat_true = utils .split_y_true (y_true , n_outcome_true_cols = 1 )[1 ]
78
+ treat_pred = utils .split_y_pred (
79
+ y_pred ,
80
+ n_outcome_pred_cols = self ._n_outcome_pred_cols ,
81
+ n_treatment_pred_cols = self ._n_treatment_pred_cols ,
82
+ )[2 ]
83
+ treat_true = utils .split_y_true (y_true , n_outcome_true_cols = self ._n_outcome_true_cols )[1 ]
48
84
super ().update_state (y_true = treat_true , y_pred = treat_pred , sample_weight = sample_weight )
49
85
50
86
@@ -53,21 +89,53 @@ def update_state(self, y_true, y_pred, sample_weight=None):
53
89
class TreatmentMeanAbsoluteError (tf .keras .metrics .MeanAbsoluteError ):
54
90
"""MSE computed on the ouptut for weighted average outcome prediction."""
55
91
92
+ def __init__ (
93
+ self ,
94
+ n_outcome_pred_cols : int ,
95
+ n_treatment_pred_cols : int ,
96
+ n_outcome_true_cols : int ,
97
+ ** kwargs ,
98
+ ):
99
+ super ().__init__ ()
100
+ self ._n_outcome_true_cols = n_outcome_true_cols
101
+ self ._n_outcome_pred_cols = n_outcome_pred_cols
102
+ self ._n_treatment_pred_cols = n_treatment_pred_cols
103
+
56
104
def update_state (self , y_true , y_pred , sample_weight = None ):
57
105
"""Updates state"""
58
- treat_pred = utils .split_y_pred (y_pred , n_outcome_pred_cols = 1 , n_treatment_pred_cols = 2 )[2 ]
59
- treat_true = utils .split_y_true (y_true , n_outcome_true_cols = 1 )[1 ]
106
+ treat_pred = utils .split_y_pred (
107
+ y_pred ,
108
+ n_outcome_pred_cols = self ._n_outcome_pred_cols ,
109
+ n_treatment_pred_cols = self ._n_treatment_pred_cols ,
110
+ )[2 ]
111
+ treat_true = utils .split_y_true (y_true , n_outcome_true_cols = self ._n_treatment_pred_cols )[1 ]
60
112
super ().update_state (y_true = treat_true , y_pred = treat_pred , sample_weight = sample_weight )
61
113
62
114
63
115
@tf .keras .utils .register_keras_serializable (package = "pypsps" )
64
116
class OutcomeMeanSquaredError (tf .keras .metrics .MeanSquaredError ):
65
117
"""MSE computed on the ouptut for weighted average outcome prediction."""
66
118
119
+ def __init__ (
120
+ self ,
121
+ n_outcome_pred_cols : int ,
122
+ n_treatment_pred_cols : int ,
123
+ n_outcome_true_cols : int ,
124
+ ** kwargs ,
125
+ ):
126
+ super ().__init__ ()
127
+ self ._n_outcome_true_cols = n_outcome_true_cols
128
+ self ._n_outcome_pred_cols = n_outcome_pred_cols
129
+ self ._n_treatment_pred_cols = n_treatment_pred_cols
130
+
67
131
def update_state (self , y_true , y_pred , sample_weight = None ):
68
132
"""Updates state"""
69
- avg_outcome = utils .agg_outcome_pred (y_pred , n_outcome_pred_cols = 2 , n_treatment_pred_cols = 1 )
70
- outcome_true = utils .split_y_true (y_true , n_outcome_true_cols = 1 )[0 ]
133
+ avg_outcome = utils .agg_outcome_pred (
134
+ y_pred ,
135
+ n_outcome_pred_cols = self ._n_outcome_pred_cols ,
136
+ n_treatment_pred_cols = self ._n_treatment_pred_cols ,
137
+ )
138
+ outcome_true = utils .split_y_true (y_true , n_outcome_true_cols = self ._n_outcome_true_cols )[0 ]
71
139
super ().update_state (y_true = outcome_true , y_pred = avg_outcome , sample_weight = sample_weight )
72
140
73
141
@@ -76,10 +144,26 @@ def update_state(self, y_true, y_pred, sample_weight=None):
76
144
class OutcomeMeanAbsoluteError (tf .keras .metrics .MeanAbsoluteError ):
77
145
"""MSE computed on the ouptut for weighted average outcome prediction."""
78
146
147
+ def __init__ (
148
+ self ,
149
+ n_outcome_pred_cols : int ,
150
+ n_treatment_pred_cols : int ,
151
+ n_outcome_true_cols : int ,
152
+ ** kwargs ,
153
+ ):
154
+ super ().__init__ ()
155
+ self ._n_outcome_true_cols = n_outcome_true_cols
156
+ self ._n_outcome_pred_cols = n_outcome_pred_cols
157
+ self ._n_treatment_pred_cols = n_treatment_pred_cols
158
+
79
159
def update_state (self , y_true , y_pred , sample_weight = None ):
80
160
"""Updates state"""
81
- avg_outcome = utils .agg_outcome_pred (y_pred , n_outcome_pred_cols = 2 , n_treatment_pred_cols = 1 )
82
- outcome_true = utils .split_y_true (y_true , n_outcome_true_cols = 1 )[0 ]
161
+ avg_outcome = utils .agg_outcome_pred (
162
+ y_pred ,
163
+ n_outcome_pred_cols = self ._n_outcome_pred_cols ,
164
+ n_treatment_pred_cols = self ._n_treatment_pred_cols ,
165
+ )
166
+ outcome_true = utils .split_y_true (y_true , n_outcome_true_cols = self ._n_outcome_true_cols )[0 ]
83
167
super ().update_state (y_true = outcome_true , y_pred = avg_outcome , sample_weight = sample_weight )
84
168
85
169
0 commit comments