Feature/design fixing specials #100


Merged
33 commits merged on Jun 3, 2025
Commits (33)
4bd68a4
New vignette on prediction model class (ml_model)
kkholst Jan 21, 2025
6cad405
Merge branch 'dev' into docs/vignette-prediction-model
kkholst Feb 4, 2025
b6ee439
predictor_nb naive bayes
kkholst Feb 4, 2025
26f64c4
predictor_nb naive bayes
kkholst Feb 4, 2025
82255ae
roxygen
kkholst Feb 4, 2025
495ed57
pbc example
kkholst Feb 4, 2025
a2e336f
merge dev
benesom May 12, 2025
2fbfeed
wip
benesom May 12, 2025
067acd0
introduction
benesom May 13, 2025
e70d13c
some notes
benesom May 13, 2025
626c3a1
Merge branch 'dev' of gh-private:kkholst/targeted into docs/vignette-…
benesom May 14, 2025
a9014d6
tests
benesom May 23, 2025
02eeda6
tests + roxygen
benesom May 23, 2025
2dc15e4
roxygen
benesom May 23, 2025
3b0d962
merge dev
benesom May 23, 2025
3a97784
implement summary method + simple tests
benesom May 25, 2025
e09449c
minor
benesom May 25, 2025
db98d2c
removing formals public field from learner r6 class
benesom May 25, 2025
1dbc72d
Merge branch 'dev' of gh-private:kkholst/targeted into docs/vignette-…
benesom May 25, 2025
9685093
Merge branch 'feature/learner-summary-method-benesom' into docs/vigne…
benesom May 25, 2025
8b4c106
Merge branch 'feature/renaming-predictor-grf-benesom' into docs/vigne…
benesom May 25, 2025
b77403d
wip
benesom May 26, 2025
05b9341
wip
benesom May 26, 2025
a4a3ccf
merge dev
benesom May 26, 2025
fe76c6c
wip
benesom May 26, 2025
cc401b4
Merge branch 'dev' into docs/vignette-prediction-model
kkholst Jun 1, 2025
67cf230
print working for learner with atomic vector result
kkholst Jun 1, 2025
e4367c6
print design working with zero-dim (i.e. summary.design res)
kkholst Jun 1, 2025
9a0a371
with new argument design.matrix. When FALSE only specials will be ex…
kkholst Jun 3, 2025
3baf920
default remove specials from stored formula. Specials are now correct…
kkholst Jun 3, 2025
b54db7b
earth is already handling offset in formula, so shouldn't be added to…
kkholst Jun 3, 2025
b64b28b
unit tests
kkholst Jun 3, 2025
89bef1d
response() should calc. design-matrix. Avoid unnecessary design-matri…
kkholst Jun 3, 2025
121 changes: 86 additions & 35 deletions R/design.R
@@ -44,45 +44,66 @@ model.extract2 <- function(frame, component) {
#' @param specials.call (call) specials optionally defined as a call-type
#' @param xlev a named list of character vectors giving the full set of levels
#' to be assumed for each factor
#' @param design.matrix (logical) if FALSE then only the response and specials
#' are returned. Otherwise, the design matrix `x` is also part of the returned
#' object.
#' @return An object of class 'design'
#' @author Klaus Kähler Holst
#' @export
design <- function(formula, data, ..., # nolint
intercept = FALSE,
response = FALSE,
rm_envir = FALSE,
specials = c("weights", "offset"),
specials = NULL,
specials.call = NULL,
xlev = NULL) {
tt <- terms(formula, data = data, specials = specials)
xlev = NULL,
design.matrix = TRUE) {
dots <- substitute(list(...))
if ("subset" %in% names(dots)) stop(
"subset is not an allowed specials argument for targeted::design"
)
mf <- model.frame(tt,
data = data, ...,
xlev = xlev,
drop.unused.levels = FALSE
)
mf <- model.frame(tt, data=data, ...)
tt <- terms(formula, data = data, specials = specials)

if (!design.matrix) { # only extract specials, response
des <- attr(tt, "factors")
sterm.list <- c()
for (s in specials) {
sterm <- rownames(des)[attr(tt, "specials")[[s]]]
sterm.list <- c(sterm.list, sterm)
}
fs <- update(formula, ~1)
if (length(sterm.list) > 0) {
upd <- paste(" ~ . - ", paste(sterm.list, collapse = " - "))
fs <- reformulate(paste(sterm.list, collapse = " + "))
fs <- update(formula, fs)
formula <- update(formula, upd)
}
mf <- model.frame(fs, data=data, ...)
} else { # also extract design matrix
mf <- model.frame(tt,
data = data, ...,
xlev = xlev,
drop.unused.levels = FALSE
)
if (is.null(xlev)) {
xlev <- .getXlevels(tt, mf)
}
xlev0 <- xlev
}

y <- model.response(mf, type = "any")
# delete response to generate design matrix when making predictions
if (!response) tt <- delete.response(tt)
has_intercept <- attr(tt, "intercept") == 1L
specials <- union(
specials,
names(dots)[-1] # removing "" at first position when calling dots, which
) # is a call object
if (is.null(xlev)) {
xlev <- .getXlevels(tt, mf)
}
xlev0 <- xlev

term.labels <- attr(tt, "term.labels") # predictors
specials.list <- c()
if (length(specials) > 0) {
des <- attr(tt, "factors")

sterm.list <- c()

for (s in specials) {
w <- eval(substitute(model.extract2(mf, s), list(s = s)))
specials.list <- c(specials.list, list(w))
@@ -91,20 +112,32 @@ design <- function(formula, data, ..., # nolint
}
names(specials.list) <- specials
if (length(sterm.list) > 0) {
upd <- paste(" ~ . - ", paste(sterm.list, collapse = " - "))
reformulate
tmp.terms <- update(tt, upd) |> terms()
xlev0 <- .getXlevels(tmp.terms, mf)
mf <- model.frame(tmp.terms,
data = data, ...,
xlev = xlev0,
drop.unused.levels = FALSE
)
if ((nrow(attr(tt, "factors")) - attr(tt, "response")) ==
length(sterm.list)) {
# only specials on the rhs, remove everything
formula <- update(formula, ~1)
} else {
# remove specials from formula
formula <- drop.terms(tt,
unlist(attr(tt, "specials")) -
attr(tt, "response"),
keep.response = TRUE)
}
if (design.matrix) {
xlev0[sterm.list] <- NULL
mf <- model.frame(formula,
data = data, ...,
xlev = xlev0,
drop.unused.levels = FALSE
)
# predictors without the specials
term.labels <- setdiff(term.labels,
unlist(sterm.list))
term.labels <- setdiff(term.labels,
unlist(sterm.list))

}
}
}

if (!is.null(specials.call)) {
specials.list2 <- eval(specials.call, data)
for (n in names(specials.list2)) {
@@ -114,22 +147,31 @@ design <- function(formula, data, ..., # nolint
}
}

x <- model.matrix(mf, data = data, xlev = xlev0)
has_intercept <- attr(tt, "intercept") == 1L
if (!intercept && has_intercept) {
has_intercept <- FALSE
x <- x[, -1, drop = FALSE]
if (design.matrix) {
x <- model.matrix(mf, data = data, xlev = xlev0)
if (!intercept && has_intercept) {
has_intercept <- FALSE
x <- x[, -1, drop = FALSE]
}
} else {
term.labels <- NULL
x <- NULL
}

# delete response to generate design matrix when making predictions
if (!response) tt <- delete.response(tt)

if (rm_envir) attr(tt, ".Environment") <- NULL
if (is.null(specials.call)) specials.call <- dots

res <- c(
list(
formula = formula, # formula without specials
terms = tt,
term.labels = term.labels,
xlevels = xlev,
x = x, y = y,
design.matrix = design.matrix,
intercept = has_intercept,
data = data[0, ], ## Empty data.frame to capture structure of data
specials = specials,
@@ -146,6 +188,7 @@ update.design <- function(object, data = NULL, ...) {
return(
design(object$terms,
data = data,
design.matrix = object$design.matrix,
xlev = object$xlevels,
intercept = object$intercept,
specials = object$specials,
@@ -172,7 +215,7 @@ terms.design <- function(x, specials, ...) {

#' @export
summary.design <- function(object, ...) {
object$x <- object$x[0, ]
object$x <- object$x[0, , drop=FALSE]
object$y <- NULL
for (i in object$specials) object[[i]] <- NULL
return(object)
@@ -182,7 +225,11 @@ summary.design <- function(object, ...) {
print.design <- function(x, n=2, ...) {
cat_ruler(" design object ", 10)
cat(sprintf("\nresponse (length: %s)", length(x$y)))
lava::Print(x$y, n = n, ...)
if (length(x$y) > 0) {
lava::Print(x$y, n = n, ...)
} else {
cat("\n")
}
specials <- c()
for (nam in x$specials) {
if (!is.null(x[[nam]])) {
@@ -197,7 +244,11 @@ print.design <- function(x, n=2, ...) {
cat("\n")
}
cat(sprintf("\ndesign matrix (dim: %s)\n", paste0(dim(x$x), collapse = ", ")))
lava::Print(x$x, n = n, ...)
if (NROW(x$x) > 0) {
lava::Print(x$x, n = n, ...)
} else {
print(x$x)
}
return(invisible(x))
}

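For context, a minimal sketch of how the reworked design() might be called (not part of the diff; the toy data frame d is hypothetical). With the new design.matrix = FALSE argument only the response and specials are extracted, and no model matrix is constructed:

library(targeted)
d <- data.frame(y = rnorm(10), x1 = rnorm(10), w = runif(10))  # toy data (hypothetical)

# Full design object: response y, the weights special, and design matrix x
full <- design(y ~ x1 + weights(w), data = d, specials = "weights")
dim(full$x)      # 10 x 1; only x1, since intercept = FALSE by default

# Response and specials only: x is NULL, no model matrix is built
slim <- design(y ~ x1 + weights(w), data = d, specials = "weights",
               design.matrix = FALSE)
is.null(slim$x)  # TRUE
slim$weights     # the extracted weights vector
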
45 changes: 36 additions & 9 deletions R/learner.R
@@ -80,6 +80,9 @@ learner <- R6::R6Class("learner", # nolint
#' @param estimate.args optional arguments to estimate function
#' @param specials optional specials terms (weights, offset,
#' id, subset, ...) passed on to [targeted::design]
#' @param formula.keep.specials if FALSE (default) then special terms defined
#' by `specials` are removed from the formula before it is passed to the
#' estimate function
#' @param intercept (logical) include intercept in design matrix
initialize = function(formula = NULL,
estimate,
@@ -88,6 +91,7 @@
estimate.args = NULL,
info = NULL,
specials = c(),
formula.keep.specials = FALSE,
intercept = FALSE
) {
estimate <- add_dots(estimate)
@@ -115,11 +119,26 @@
} else {
if (fit_formula) { # Formula in arguments of estimation procedure
private$fitfun <- function(data, ...) {
args <- private$update_args(private$estimate.args, ...)
des <- do.call(
targeted::design,
c(list(formula = private$.formula,
data = data,
design.matrix = FALSE),
private$des.args
)
)
args <- private$update_args(private$estimate.args, ...)
form <- private$.formula
if (!private$formula.keep.specials) form <- des$formula
args <- c(
args, list(formula = private$.formula, data = data)
args, list(formula = form, data = data)
)
return(do.call(private$init.estimate, args))
if (length(des$specials) > 0) {
args <- c(args, des[des$specials])
}
return(structure(do.call(private$init.estimate, args),
design = summary(des)
))
}
} else {
# Formula automatically processed into design matrix & response
@@ -140,7 +159,7 @@
}
}
private$predfun <- function(object, data, ...) {
if (fit_formula || no_formula) {
if (no_formula) {
predict_args_call <- private$update_args(predict.args, ...)
args <- c(list(object, newdata = data), predict_args_call)
} else {
@@ -151,15 +170,19 @@
}
predict_args_call <- predict.args
predict_args_call[names(args)] <- args

newdata <- data
if (!fit_formula) {
newdata <- model.matrix(des)
}
args <- c(list(object,
newdata = model.matrix(des)
newdata = newdata
), predict_args_call)
}
return(do.call(private$init.predict, args))
}
}
private$.formula <- formula
private$formula.keep.specials <- formula.keep.specials
self$info <- info
private$init <- list(
estimate.args = estimate.args,
@@ -247,15 +270,14 @@
return(obj)
},


#' @description
#' Extract response from data
#' @param eval when FALSE return the untransformed outcome
#' (i.e., return 'a' if formula defined as I(a==1) ~ ...)
#' @param ... additional arguments to [targeted::design]
response = function(data, eval = TRUE, ...) {
if (eval) {
return(self$design(data = data, ...)$y)
return(self$design(data = data, ..., design.matrix = FALSE)$y)
}
if (is.null(private$.formula)) return(NULL)
newf <- update(private$.formula, ~1)
@@ -303,6 +325,10 @@
# @field .formula Model formula object // uses dot as a pre-fix to allow
# using formula as an active binding
.formula = NULL,
# @field formula.keep.specials if FALSE (default) then special terms defined
# by `specials` are removed from the formula before it is passed to the
# estimate function
formula.keep.specials = NULL,
# @field init Information on the initialized model
init = NULL,
# When x$clone(deep=TRUE) is called, the deep_clone gets invoked once for
@@ -388,7 +414,8 @@ learner_print <- function(self, private) {
if (!is.null(private$fitted)) {
cat_ruler("\u2500", 18)
fit <- self$fit
if (!is.null(fit$call)) fit$call <- substitute()
attr(fit, "design") <- NULL
if (!is.atomic(fit) && !is.null(fit$call)) fit$call <- substitute()
cat(capture.output(print(fit)), sep ="\n")
}

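A hypothetical sketch of the new formula.keep.specials behaviour (the toy data and estimate function are illustrative, not part of the diff). With the default formula.keep.specials = FALSE, the weights(w) term is stripped from the formula and the extracted vector is forwarded to the estimate function as a named weights argument:

lr <- learner$new(
  formula = y ~ x1 + weights(w),
  estimate = function(formula, data, weights = NULL, ...) {
    environment(formula) <- environment()  # let lm() find the local `weights`
    lm(formula, data = data, weights = weights)  # formula arrives as y ~ x1
  },
  specials = "weights"
)
fit <- lr$estimate(d)  # d as in the design() sketch above
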
1 change: 0 additions & 1 deletion R/learner_mars.R
@@ -36,7 +36,6 @@ learner_mars <- function(formula,
),
list(...)
)
args$specials <- union(args$specials, c("offset"))

args$estimate <- function(formula, data, ...) earth::earth(formula, data, ...)
args$predict <- function(object, newdata, ...) {
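The learner_mars() change drops the automatic "offset" special: per commit b54db7b, earth::earth() already handles offset() terms in the formula itself, so the term is no longer extracted and passed separately. A hypothetical call (data frame d2 is illustrative):

d2 <- data.frame(y = rnorm(20), x1 = rnorm(20), ex = runif(20) + 1)  # toy data
lr <- learner_mars(y ~ x1 + offset(log(ex)))
fit <- lr$estimate(d2)  # offset(log(ex)) stays in the formula seen by earth()
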
7 changes: 3 additions & 4 deletions inst/tinytest/test_design.R
@@ -75,7 +75,7 @@ test_design_ellipsis()
test_design_specials <- function() {
# offset is correctly identified as a special variable and not added as a
# covariate
dd <- design(y ~ offset(x1), ddata)
dd <- design(y ~ offset(x1), ddata, specials="offset")

expect_equal(ncol(dd$x), 0)
offset_expect <- ddata$x1
@@ -89,7 +89,7 @@
# an offset variable is not changed
ddata1 <- ddata
ddata1$offset <- 1
dd <- design(y ~ offset + x1, ddata1)
dd <- design(y ~ offset + x1, ddata1, specials="offset")
expect_equivalent(
as.matrix(ddata1[, c("offset", "x1")]),
dd$x
@@ -114,8 +114,7 @@
expect_equal(ddata$x1, unname(dd$offset))

# test default weight special
weights <- identity
dd <- design(y ~ weights(x1), ddata)
dd <- design(y ~ weights(x1), ddata, specials="weights")
expect_equal(unname(dd$weights), ddata$x1)
}
test_design_specials()