Skip to content

Commit c885c1c

Browse files
authored
Merge pull request #437 from mlr-org/rm_scores
Updates in several scores
2 parents 1e11c7f + 8864d62 commit c885c1c

File tree

100 files changed

+1479
-2006
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

100 files changed

+1479
-2006
lines changed

DESCRIPTION

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3proba
22
Title: Probabilistic Supervised Learning for 'mlr3'
3-
Version: 0.8.0
3+
Version: 0.8.1
44
Authors@R: c(
55
person("Raphael", "Sonabend", , "[email protected]", role = "aut",
66
comment = c(ORCID = "0000-0001-9225-4654")),
@@ -153,16 +153,14 @@ Collate:
153153
'autoplot.R'
154154
'bibentries.R'
155155
'breslow.R'
156-
'cindex.R'
157156
'data.R'
157+
'helper_measures.R'
158158
'helpers.R'
159159
'histogram.R'
160-
'integrated_scores.R'
161160
'mlr3proba-package.R'
162161
'pecs.R'
163162
'pipelines.R'
164163
'plot_probregr.R'
165-
'scoring_rule_erv.R'
166-
'surv_measures.R'
167164
'surv_return.R'
165+
'weighted_survival_score.R'
168166
'zzz.R'

NAMESPACE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ S3method(pecs,PredictionSurv)
3939
S3method(pecs,list)
4040
S3method(plot,TaskDens)
4141
S3method(plot,TaskSurv)
42-
export(.c_weight_survival_score)
4342
export(.surv_return)
4443
export(LearnerCompRisks)
4544
export(LearnerCompRisksAalenJohansen)

NEWS.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# mlr3proba 0.8.1
2+
3+
* feat: `surv.logloss` and `surv.rcll` now use linear interpolation of S(t) to calculate the density f(t)
4+
* fix: `surv.mae`/`surv.mse`/`surv.rmse` scores return `NA` when test set has only censored observations
5+
* fix: fix bug in msr(`surv.brier`) that resulted in 0 division instead of `eps` division (`Inf` values are filtered out so this was kinda masking the inflation of ISBS)
6+
* refactor: remove `se` argument from most of the scores (not practically used)
7+
* refactor: remove `method` argument from integrated survival scores (the previous default, `method = 2`, time-weighted integration, is now always used)
8+
* **BREACKING CHANGE**: we removed all experimental `proper` scoring rules (and `remove_obs` argument).
9+
Scores yield the same results as before with the default option `proper = FALSE`
10+
* refactor: all private functions start with `.` now and are adequately (privately) documented. Code was refactored for clarity
11+
* refactor: all internal `Rcpp` measure functions
12+
* refine doc in lots of measures
13+
114
# mlr3proba 0.8.0
215

316
* Compatibility with `mlr3` v1.0.0 (`weights_learner`) and `mlr3pipelines` v0.8.0

R/MeasureCompRisksAUC.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ MeasureCompRisksAUC = R6Class(
9797
cif_mat = cif[[as.character(cause)]]
9898

9999
# get CIF on the time horizon
100-
mat = interpolate_cif(cif_mat, new_times = time_horizon)
100+
mat = .interp_cif(cif_mat, eval_times = time_horizon)
101101

102102
# calculate AUC(t) score
103-
res = riskRegression_score(
103+
res = .riskRegr_score(
104104
mat_list = list(mat),
105105
metric = "auc",
106106
data = data,
@@ -118,10 +118,10 @@ MeasureCompRisksAUC = R6Class(
118118
cif_mat = cif[[cause]]
119119

120120
# get CIF on the time horizon
121-
mat = interpolate_cif(cif_mat, new_times = time_horizon)
121+
mat = .interp_cif(cif_mat, eval_times = time_horizon)
122122

123123
# calculate AUC(t) score
124-
res = riskRegression_score(
124+
res = .riskRegr_score(
125125
mat_list = list(mat),
126126
metric = "auc",
127127
data = data,

R/MeasureSurvAUC.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ MeasureSurvAUC = R6Class("MeasureSurvAUC",
2222

2323
super$initialize(
2424
id = id,
25-
range = 0:1,
25+
range = c(0, 1),
2626
minimize = FALSE,
2727
packages = "survAUC",
2828
predict_type = "lp",
@@ -36,7 +36,6 @@ MeasureSurvAUC = R6Class("MeasureSurvAUC",
3636

3737
private = list(
3838
.score = function(prediction, learner, task, train_set, FUN, ...) {
39-
4039
args = list()
4140
ps = self$param_set$values
4241

R/MeasureSurvCalibrationAlpha.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
#'
77
#' @description
88
#' This calibration method is defined by estimating
9-
#' \deqn{\hat{\alpha} = \sum \delta_i / \sum H_i(T_i)}
10-
#' where \eqn{\delta} is the observed censoring indicator from the test data,
11-
#' \eqn{H_i} is the predicted cumulative hazard, and \eqn{T_i} is the observed
12-
#' survival time (event or censoring).
9+
#' \deqn{\hat{\alpha} = \frac{\sum_{i=1}^n \delta_i}{\sum_{i=1}^n H_i(T_i)}}
10+
#' where \eqn{\delta} is the observed censoring indicator from the test data
11+
#' \eqn{n} observations), \eqn{H_i} is the predicted cumulative hazard, and \eqn{T_i}
12+
#' is the observed survival time (event or censoring).
1313
#'
1414
#' The standard error is given by
15-
#' \deqn{\hat{\alpha_{se}} = exp(1/\sqrt{\sum \delta_i})}
15+
#' \deqn{\hat{\alpha_{se}} = e^{1/\sqrt{\sum \delta_i}}}
1616
#'
1717
#' The model is well calibrated if the estimated \eqn{\hat{\alpha}} coefficient
1818
#' (returned score) is equal to 1.
@@ -75,11 +75,11 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
7575
truth = prediction$truth
7676
all_times = truth[, 1L] # both event times and censoring times
7777
status = truth[, 2L]
78-
deaths = sum(status)
78+
n_events = sum(status)
7979

8080
ps = self$param_set$values
8181
if (ps$se) {
82-
return(exp(1 / sqrt(deaths)))
82+
return(exp(1 / sqrt(n_events)))
8383
} else {
8484
distr = prediction$data$distr
8585

@@ -113,7 +113,7 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
113113
# Inf => case where censoring occurs at last time point
114114
# 0 => case where survival probabilities are all 1
115115
cumhaz[cumhaz == Inf | cumhaz == 0] = ps$eps
116-
out = deaths / sum(cumhaz)
116+
out = n_events / sum(cumhaz)
117117

118118
if (ps$method == "diff") {
119119
out = abs(1 - out)

R/MeasureSurvCalibrationBeta.R

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#' This calibration method fits the predicted linear predictor from a Cox PH
77
#' model as the only predictor in a new Cox PH model with the test data as
88
#' the response.
9-
#' \deqn{h(t|x) = h_0(t)exp(\beta \times lp)}
9+
#' \deqn{h(t|x) = h_0(t)e^{\beta \times lp}}
1010
#' where \eqn{lp} is the predicted linear predictor on the test data.
1111
#'
1212
#' The model is well calibrated if the estimated \eqn{\hat{\beta}} coefficient
@@ -56,7 +56,8 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
5656
predict_type = "lp",
5757
label = "Van Houwelingen's Beta",
5858
man = "mlr3proba::mlr_measures_surv.calib_beta",
59-
param_set = ps
59+
param_set = ps,
60+
properties = "na_score"
6061
)
6162
}
6263
),
@@ -68,21 +69,22 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
6869

6970
if (inherits(fit, "try-error")) {
7071
return(NA)
71-
} else {
72-
ps = self$param_set$values
72+
}
7373

74-
if (ps$se) {
75-
return(fit$coefficients[, "se(coef)"])
76-
} else {
77-
out = fit$coefficients[, "coef"]
74+
ps = self$param_set$values
7875

79-
if (ps$method == "diff") {
80-
out = abs(1 - out)
81-
}
76+
if (ps$se) {
77+
return(fit$coefficients[, "se(coef)"])
78+
} else {
79+
out = fit$coefficients[, "coef"]
8280

83-
return(out)
81+
if (ps$method == "diff") {
82+
out = abs(1 - out)
8483
}
84+
85+
return(out)
8586
}
87+
8688
}
8789
)
8890
)

R/MeasureSurvChamblessAUC.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ MeasureSurvChamblessAUC = R6Class("MeasureSurvChamblessAUC",
4242
private = list(
4343
.score = function(prediction, learner, task, train_set, ...) {
4444
if (!inherits(learner, "LearnerSurvCoxPH")) {
45-
stop("surv.chambless_auc only compatible with Cox PH models")
45+
stop("Only compatible with Cox PH models")
4646
}
47+
4748
ps = self$param_set$values
4849
if (!ps$integrated) {
4950
msg = "If `integrated=FALSE` then `times` should be a scalar numeric."

R/MeasureSurvCindex.R

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
#' Weighting applied to tied rankings, default is to give them half (0.5) weighting.
4545
#'
4646
#' @references
47-
#' `r format_bib("peto_1972", "harrell_1982", "goenen_2005", "schemper_2009", "uno_2011")`
47+
#' `r format_bib("peto_1972", "harrell_1982", "gonen_2005", "schemper_2009", "uno_2011")`
4848
#'
4949
#' @template param_range
5050
#' @template param_minimize
@@ -90,11 +90,10 @@ MeasureSurvCindex = R6Class("MeasureSurvCindex",
9090

9191
super$initialize(
9292
id = "surv.cindex",
93-
range = 0:1,
93+
range = c(0, 1),
9494
minimize = FALSE,
95-
packages = character(),
9695
predict_type = "crank",
97-
properties = character(),
96+
properties = "na_score",
9897
label = "Concordance Index",
9998
man = "mlr3proba::mlr_measures_surv.cindex",
10099
param_set = ps
@@ -108,43 +107,39 @@ MeasureSurvCindex = R6Class("MeasureSurvCindex",
108107
.score = function(prediction, task, train_set, ...) {
109108
ps = self$param_set$values
110109

111-
# calculate t_max (cutoff time horizon)
112-
if (is.null(ps$t_max) && !is.null(ps$p_max)) {
110+
# Determine cutoff time horizon (t_max)
111+
t_max = ps$t_max
112+
if (is.null(t_max) && !is.null(ps$p_max)) {
113113
truth = prediction$truth
114-
unique_times = unique(sort(truth[, "time"]))
114+
unique_times = unique(sort(truth[, 1L]))
115115
surv = survival::survfit(truth ~ 1)
116-
indx = which(1 - (surv$n.risk / surv$n) > ps$p_max)
117-
if (length(indx) == 0L) {
118-
t_max = NULL # t_max calculated in `cindex()`
119-
} else {
120-
# first time point that surpasses the specified
121-
# `p_max` proportion of censoring
122-
t_max = surv$time[indx[1L]]
123-
}
124-
} else {
125-
t_max = ps$t_max
116+
censored_proportion = 1 - (surv$n.risk / surv$n)
117+
indx = which(censored_proportion > ps$p_max)
118+
119+
# First time point that surpasses `p_max` censoring
120+
t_max = if (length(indx) > 0L) surv$time[indx[1L]] else NULL
126121
}
127122

128-
if (ps$weight_meth == "GH") {
129-
return(gonen(prediction$crank, ps$tiex))
130-
} else if (ps$weight_meth == "I") {
131-
return(cindex(prediction$truth, prediction$crank, t_max, ps$weight_meth, ps$tiex))
132-
} else {
133-
if (is.null(task) | is.null(train_set)) {
134-
stop("'task' and 'train_set' required for all weighted C-indexes (except GH).")
135-
}
136-
return(cindex(prediction$truth, prediction$crank, t_max, ps$weight_meth,
137-
ps$tiex, task$truth(train_set), ps$eps))
123+
# Select weighting method
124+
weight_meth = ps$weight_meth
125+
126+
if (weight_meth == "I") {
127+
return(.cindex(prediction$truth, prediction$crank, t_max, weight_meth, ps$tiex))
138128
}
129+
130+
if (weight_meth == "GH") {
131+
return(.gonen(prediction$crank, ps$tiex))
132+
}
133+
134+
# All other methods require task and train_set
135+
if (is.null(task) || is.null(train_set)) {
136+
stopf("'task' and 'train_set' are required for weighted C-index method '%s'", weight_meth)
137+
}
138+
139+
train_truth = task$truth(train_set)
140+
.cindex(prediction$truth, prediction$crank, t_max, weight_meth, ps$tiex, train_truth, ps$eps)
139141
}
140142
)
141143
)
142144

143-
gonen = function(crank, tiex) {
144-
assert_numeric(crank, any.missing = FALSE)
145-
assert_number(tiex)
146-
147-
c_gonen(sort(crank), tiex)
148-
}
149-
150145
register_measure("surv.cindex", MeasureSurvCindex)

R/MeasureSurvDCalibration.R

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -89,26 +89,13 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
8989
true_times = prediction$truth[, 1L]
9090

9191
# predict individual probability of death at observed event time
92-
# bypass distr6 construction if possible
93-
if (inherits(prediction$data$distr, "array")) {
94-
surv = prediction$data$distr
95-
if (length(dim(surv)) == 3) {
96-
# survival 3d array, extract median
97-
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
98-
}
99-
times = as.numeric(colnames(surv))
92+
surv = .get_surv_matrix(prediction)
93+
times = as.numeric(colnames(surv))
10094

101-
extend_times = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
102-
si = diag(extend_times(true_times, times, cdf = t(1 - surv), FALSE, FALSE))
103-
} else {
104-
distr = prediction$distr
105-
if (inherits(distr, c("Matdist", "Arrdist"))) {
106-
si = diag(distr$survival(true_times))
107-
} else { # VectorDistribution or single Distribution, e.g. WeightDisc()
108-
si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L)))
109-
}
110-
}
111-
# remove zeros
95+
extend_times = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
96+
si = diag(extend_times(true_times, times, cdf = t(1 - surv), FALSE, FALSE))
97+
98+
# replace zeros
11299
si = map_dbl(si, function(.x) max(.x, 1e-5))
113100
# index of associated bucket
114101
js = ceiling(B * si)

0 commit comments

Comments
 (0)