1
- metalearner_nnls <- function (y , pred , method = " nnls" ) {
1
+ metalearner_nnls <- function (y , pred , method = " nnls" , ... ) {
2
2
if (NCOL(pred )== 1 ) return (1.0 )
3
+ idx <- which(apply(pred , 2 , \(x ) ! any(is.na(x ))))
4
+ coefs <- rep(0 , ncol(pred ))
5
+ pred <- pred [, idx , drop = FALSE ]
3
6
if (method == " nnls" ) {
4
7
res <- nnls :: nnls(A = pred , b = y )
5
- coefs <- res $ x
8
+ coefs [ idx ] <- res $ x
6
9
} else {
7
10
res <- glmnet :: glmnet(
8
11
y = y , x = pred ,
9
12
intercept = FALSE ,
10
13
lambda = 0 ,
11
14
lower.limits = rep(0 , ncol(pred ))
12
15
)
13
- coefs <- as.vector(coef(res ))[- 1 ]
16
+ coefs [ idx ] <- as.vector(coef(res ))[- 1 ]
14
17
}
15
18
if (any(is.na(coefs ))) coefs [is.na(coefs )] <- 0
16
19
if (all(coefs == 0 )) coefs [1 ] <- 1
17
20
return (coefs / sum(coefs ))
18
21
}
19
22
23
+ metalearner_discrete <- function (y , pred , risk , ... ) {
24
+ weights <- rep(0 , NCOL(pred ))
25
+ risk [is.na ] <- Inf
26
+ weights [which.min(risk )[1 ]] <- 1
27
+ return (weights )
28
+ }
29
+
30
+
20
31
get_learner_names <- function (model.list , name.prefix ) {
21
32
.names <- names(model.list )
22
33
if (is.null(.names )) .names <- rep(" " , length(model.list ))
@@ -93,9 +104,26 @@ superlearner <- function(learners,
93
104
name.prefix = NULL ,
94
105
... ) {
95
106
pred_mod <- function (models , data ) {
96
- res <- lapply(models , \(x ) x $ predict(data ))
97
- return (Reduce(cbind , res ))
107
+ n <- nrow(data )
108
+ res <- matrix (NA , nrow = n , ncol = length(models ))
109
+ for (i in seq_along(models )) {
110
+ if (! is.null(models [[i ]]$ fit )) {
111
+ res [, i ] <- tryCatch(
112
+ models [[i ]]$ predict(data ), error = function (x ) rep(NA , n )
113
+ )
114
+ }
115
+ }
116
+ return (res )
117
+ }
118
+ est_mod <- function (models , data ) {
119
+ for (i in seq_along(models )) {
120
+ v <- tryCatch(models [[i ]]$ estimate(data ), error = function (x ) NULL )
121
+ if (is.null(v )) {
122
+ models [[i ]]$ fit <- NULL
123
+ }
124
+ }
98
125
}
126
+
99
127
if (is.character(model.score )) {
100
128
model.score <- get(model.score )
101
129
}
@@ -120,7 +148,7 @@ superlearner <- function(learners,
120
148
test <- data [fold , , drop = FALSE ]
121
149
train <- data [setdiff(1 : n , fold ), , drop = FALSE ]
122
150
mod <- lapply(learners , \(x ) x $ clone(deep = TRUE ))
123
- lapply (mod , \( x ) x $ estimate( train ) )
151
+ est_mod (mod , train )
124
152
pred.test <- pred_mod(mod , test )
125
153
if (! silent ) pb()
126
154
return (list (pred = pred.test , fold = fold ))
@@ -156,14 +184,24 @@ superlearner <- function(learners,
156
184
}
157
185
mod <- lapply(learners , \(x ) x $ clone())
158
186
names(mod ) <- model.names
159
- # # Meta-learner
187
+ # Meta-learner
160
188
y <- learners [[1 ]]$ response(data )
161
189
risk <- apply(pred , 2 , \(x ) model.score(y , x ))
190
+ # Learners with failed predictions
191
+ idx <- which(apply(pred , 2 , \(x ) any(is.na(x ) | is.nan(x ))))
192
+ if (length(risk ) > 0 ) risk [idx ] <- Inf
162
193
names(risk ) <- model.names
163
- w <- meta.learner(y = y , pred = pred )
194
+ if (is.character(meta.learner )) {
195
+ if (tolower(meta.learner [1 ]) == " discrete" ) {
196
+ meta.learner <- metalearner_discrete
197
+ } else {
198
+ stop(" unrecognized meta-learner" )
199
+ }
200
+ }
201
+ w <- meta.learner(y = y , pred = pred , risk = risk )
164
202
names(w ) <- model.names
165
203
# # Full predictions
166
- lapply (mod , \( x ) x $ estimate( data ) )
204
+ est_mod (mod , data )
167
205
res <- list (
168
206
model.score = risk ,
169
207
weights = w ,
@@ -185,7 +223,6 @@ print.superlearner <- function(x, ...) {
185
223
return (print(res ))
186
224
}
187
225
188
-
189
226
# ' @title Extract ensemble weights
190
227
# ' @param object (superlearner) Fitted model.
191
228
# ' @param ... Not used.
@@ -202,7 +239,6 @@ score.superlearner <- function(x, ...) {
202
239
return (x $ model.score )
203
240
}
204
241
205
-
206
242
# ' @title Predict Method for superlearner Fits
207
243
# ' @description Obtains predictions for ensemble model or individual learners.
208
244
# ' @export
0 commit comments