Skip to content

Commit 372daa7

Browse files
committed
discrete super-learner and removing failed learners from ensemble
1 parent f031d85 commit 372daa7

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

R/learner.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,12 @@ learner <- R6::R6Class("learner", # nolint
301301
}
302302
),
303303
active = list(
304+
clear = function() invisible(private$fitted <- NULL),
304305
#' @field fit Return estimated model object.
305-
fit = function() private$fitted,
306+
fit = function(value) {
307+
if (missing(value)) return(private$fitted)
308+
else private$fitted <- NULL
309+
},
306310
#' @field formula Return model formula. Use [learner$update()][learner] to
307311
#' update the formula.
308312
formula = function() private$.formula

R/superlearner.R

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,33 @@
1-
metalearner_nnls <- function(y, pred, method = "nnls") {
1+
metalearner_nnls <- function(y, pred, method = "nnls", ...) {
22
if (NCOL(pred)==1) return(1.0)
3+
idx <- which(apply(pred, 2, \(x) !any(is.na(x))))
4+
coefs <- rep(0, ncol(pred))
5+
pred <- pred[, idx, drop = FALSE]
36
if (method == "nnls") {
47
res <- nnls::nnls(A = pred, b = y)
5-
coefs <- res$x
8+
coefs[idx] <- res$x
69
} else {
710
res <- glmnet::glmnet(
811
y = y, x = pred,
912
intercept = FALSE,
1013
lambda = 0,
1114
lower.limits = rep(0, ncol(pred))
1215
)
13-
coefs <- as.vector(coef(res))[-1]
16+
coefs[idx] <- as.vector(coef(res))[-1]
1417
}
1518
if (any(is.na(coefs))) coefs[is.na(coefs)] <- 0
1619
if (all(coefs == 0)) coefs[1] <- 1
1720
return(coefs / sum(coefs))
1821
}
1922

23+
metalearner_discrete <- function(y, pred, risk, ...) {
24+
weights <- rep(0, NCOL(pred))
25+
risk[is.na] <- Inf
26+
weights[which.min(risk)[1]] <- 1
27+
return(weights)
28+
}
29+
30+
2031
get_learner_names <- function(model.list, name.prefix) {
2132
.names <- names(model.list)
2233
if (is.null(.names)) .names <- rep("", length(model.list))
@@ -93,9 +104,26 @@ superlearner <- function(learners,
93104
name.prefix = NULL,
94105
...) {
95106
pred_mod <- function(models, data) {
96-
res <- lapply(models, \(x) x$predict(data))
97-
return(Reduce(cbind, res))
107+
n <- nrow(data)
108+
res <- matrix(NA, nrow=n, ncol=length(models))
109+
for (i in seq_along(models)) {
110+
if (!is.null(models[[i]]$fit)) {
111+
res[, i] <- tryCatch(
112+
models[[i]]$predict(data), error=function(x) rep(NA, n)
113+
)
114+
}
115+
}
116+
return(res)
117+
}
118+
est_mod <- function(models, data) {
119+
for (i in seq_along(models)) {
120+
v <- tryCatch(models[[i]]$estimate(data), error=function(x) NULL)
121+
if (is.null(v)) {
122+
models[[i]]$fit <- NULL
123+
}
124+
}
98125
}
126+
99127
if (is.character(model.score)) {
100128
model.score <- get(model.score)
101129
}
@@ -120,7 +148,7 @@ superlearner <- function(learners,
120148
test <- data[fold, , drop = FALSE]
121149
train <- data[setdiff(1:n, fold), , drop = FALSE]
122150
mod <- lapply(learners, \(x) x$clone(deep = TRUE))
123-
lapply(mod, \(x) x$estimate(train))
151+
est_mod(mod, train)
124152
pred.test <- pred_mod(mod, test)
125153
if (!silent) pb()
126154
return(list(pred = pred.test, fold = fold))
@@ -156,14 +184,24 @@ superlearner <- function(learners,
156184
}
157185
mod <- lapply(learners, \(x) x$clone())
158186
names(mod) <- model.names
159-
## Meta-learner
187+
# Meta-learner
160188
y <- learners[[1]]$response(data)
161189
risk <- apply(pred, 2, \(x) model.score(y, x))
190+
# Learners with failed predictions
191+
idx <- which(apply(pred, 2, \(x) any(is.na(x) | is.nan(x))))
192+
if (length(risk) > 0) risk[idx] <- Inf
162193
names(risk) <- model.names
163-
w <- meta.learner(y = y, pred = pred)
194+
if (is.character(meta.learner)) {
195+
if (tolower(meta.learner[1]) == "discrete") {
196+
meta.learner <- metalearner_discrete
197+
} else {
198+
stop("unrecognized meta-learner")
199+
}
200+
}
201+
w <- meta.learner(y = y, pred = pred, risk = risk)
164202
names(w) <- model.names
165203
## Full predictions
166-
lapply(mod, \(x) x$estimate(data))
204+
est_mod(mod, data)
167205
res <- list(
168206
model.score = risk,
169207
weights = w,
@@ -185,7 +223,6 @@ print.superlearner <- function(x, ...) {
185223
return(print(res))
186224
}
187225

188-
189226
#' @title Extract ensemble weights
190227
#' @param object (superlearner) Fitted model.
191228
#' @param ... Not used.
@@ -202,7 +239,6 @@ score.superlearner <- function(x, ...) {
202239
return(x$model.score)
203240
}
204241

205-
206242
#' @title Predict Method for superlearner Fits
207243
#' @description Obtains predictions for ensemble model or individual learners.
208244
#' @export

0 commit comments

Comments
 (0)