diff --git a/DESCRIPTION b/DESCRIPTION index f19e55a38..0c374b72e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -36,6 +36,7 @@ Suggests: aorsf (>= 0.0.5), actuar, apcluster, + BART (>= 2.9.4), C50, coin, CoxBoost, @@ -66,7 +67,7 @@ Suggests: mgcv, mlr3cluster, mlr3learners (>= 0.4.2), - mlr3proba, + mlr3proba (>= 0.5.3), mlr3pipelines, mvtnorm, nnet, diff --git a/NAMESPACE b/NAMESPACE index a5b0f79da..16ee3d9c3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -80,6 +80,7 @@ export(LearnerSurvGAMBoost) export(LearnerSurvGBM) export(LearnerSurvGLMBoost) export(LearnerSurvGlmnet) +export(LearnerSurvLearnerSurvBART) export(LearnerSurvLogisticHazard) export(LearnerSurvMBoost) export(LearnerSurvNelson) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R new file mode 100644 index 000000000..3c35f4c01 --- /dev/null +++ b/R/learner_BART_surv_bart.R @@ -0,0 +1,218 @@ +#' @title Survival Bayesian Additive Regression Trees Learner +#' @author bblodfon +#' @name mlr_learners_surv.bart +#' +#' @description +#' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored +#' survival data. Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. +#' +#' @details +#' Two types of prediction are returned for this learner: +#' 1. `distr`: a 3d survival array with observations as 1st dimension, time +#' points as 2nd and the posterior draws as 3rd dimension. +#' 2. `crank`: the expected mortality using [mlr3proba::.surv_return]. The parameter +#' `which.curve` decides which posterior draw (3rd dimension) will be used for the +#' calculation of the expected mortality. Note that the median posterior is +#' by default used for the calculation of survival measures that require a `distr` +#' prediction, see more info on [PredictionSurv][mlr3proba::PredictionSurv]. +#' +#' @section Custom mlr3 defaults: +#' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. 
+#'
+#' @section Custom mlr3 parameters:
+#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is
+#'   initialized to `TRUE`.
+#' - `importance` allows to choose the type of importance. Default is `count`,
+#'   see documentation of method `$importance()` for more details.
+#' - `which.curve` allows to choose which posterior draw will be used for the
+#'   calculation of the `crank` prediction. If between (0,1) it is taken as the
+#'   quantile of the curves otherwise if greater than 1 it is taken as the curve
+#'   index, can also be 'mean'. By default the **median posterior** is used,
+#'   i.e. `which.curve` is 0.5.
+#'
+#' @templateVar id surv.bart
+#' @template learner
+#'
+#' @references
+#' `r format_bib("sparapani2021nonparametric", "chipman2010bart")`
+#'
+#' @template seealso_learner
+#' @template example
+#' @export
+LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART",
+  inherit = mlr3proba::LearnerSurv,
+  public = list(
+    #' @description
+    #' Creates a new instance of this [R6][R6::R6Class] class.
+    initialize = function() {
+      param_set = ps(
+        K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")),
+        events = p_uty(default = NULL, tags = c("train", "predict")),
+        ztimes = p_uty(default = NULL, tags = c("train", "predict")),
+        zdelta = p_uty(default = NULL, tags = c("train", "predict")),
+        sparse = p_lgl(default = FALSE, tags = "train"),
+        theta = p_dbl(default = 0, tags = "train"),
+        omega = p_dbl(default = 1, tags = "train"),
+        a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"),
+        b = p_dbl(default = 1, tags = "train"),
+        augment = p_lgl(default = FALSE, tags = "train"),
+        rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
+        usequants = p_lgl(default = FALSE, tags = "train"),
+        rm.const = p_lgl(default = TRUE, tags = "train"),
+        type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"),
+        ntype = p_int(lower = 1, upper = 3, tags = "train"),
+        k = p_dbl(default = 2.0, lower = 0, tags = "train"),
+        power = p_dbl(default = 2.0, lower = 0, tags = "train"),
+        base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"),
+        offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
+        ntree = p_int(default = 50L, lower = 1L, tags = "train"),
+        numcut = p_int(default = 100L, lower = 1L, tags = "train"),
+        ndpost = p_int(default = 1000L, lower = 1L, tags = "train"),
+        nskip = p_int(default = 250L, lower = 0L, tags = "train"),
+        keepevery = p_int(default = 10L, lower = 1L, tags = "train"),
+        printevery = p_int(default = 100L, lower = 1L, tags = "train"),
+        seed = p_int(default = 99L, tags = "train"),
+        mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")),
+        nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")),
+        openmp = p_lgl(default = TRUE, tags = "predict"),
+        quiet = p_lgl(default = TRUE, tags = "predict"),
+        importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"),
+        which.curve = p_dbl(lower = 0, special_vals = list("mean"), tags = "predict")
+      )
+
+      # custom defaults
+      param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count",
+        which.curve = 0.5) # 0.5 quantile => median posterior
+
+      super$initialize(
+        id = "surv.bart",
+        packages = "BART",
+        feature_types = c("logical", "integer", "numeric"),
+        predict_types = c("crank", "distr"),
+        param_set = param_set,
+        properties = c("importance", "missings"),
+        man = "mlr3extralearners::mlr_learners_surv.bart",
+        label = "Bayesian Additive Regression Trees"
+      )
+    },
+
+    #' @description
+    #' Two types of importance scores are supported based on the value
+    #' of the parameter `importance`:
+    #' 1. `prob`: The mean selection probability of each feature in the trees,
+    #'    extracted from the slot `varprob.mean`.
+    #'    If `sparse = FALSE` (default), this is a fixed constant.
+    #'    Recommended to use this option when `sparse = TRUE`.
+    #' 2. `count`: The mean observed count of each feature in the trees (average
+    #'    number of times the feature was used in a tree decision rule across all
+    #'    posterior draws), extracted from the slot `varcount.mean`.
+    #'    These are the default importance scores.
+    #'
+    #' In both cases, higher values signify more important variables.
+    #'
+    #' @return Named `numeric()`.
+    importance = function() {
+      if (is.null(self$model$model)) {
+        stopf("No model stored")
+      }
+
+      pars = self$param_set$get_values(tags = "train")
+      # an unset `importance` falls back to the declared default ("count") instead of erroring
+      if (identical(pars$importance, "prob")) {
+        sort(self$model$model$varprob.mean[-1], decreasing = TRUE)
+      } else {
+        sort(self$model$model$varcount.mean[-1], decreasing = TRUE)
+      }
+    }
+  ),
+
+  private = list(
+    # fit the model; the training data is returned too because `.predict()` needs it
+    .train = function(task) {
+      pars = self$param_set$get_values(tags = "train")
+      pars$importance = NULL # not used in the train function
+
+      x.train = as.data.frame(task$data(cols = task$feature_names)) # nolint
+      truth = task$truth()
+      times = truth[, 1]
+      delta = truth[, 2] # delta => status
+
+      list(
+        model = invoke(
+          BART::mc.surv.bart,
+          x.train = x.train,
+          times = times,
+          delta = delta,
+          .args = pars
+        ),
+        x.train = x.train,
+        times = times,
+        delta = delta
+      )
+    },
+
+    .predict = function(task) {
+      # predict posterior survival curves for `task`; only "predict"-tagged values are read
+      pars = self$param_set$get_values(tags = "predict")
+
+      # get newdata and ensure same ordering in train and predict
+      x.test = as.data.frame(ordered_features(task, self)) # nolint
+
+      # subset parameters to use in `surv.pre.bart`
+      pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")]
+
+      # transform data to be suitable for BART survival analysis (needs train data)
+      trans_data = invoke(
+        BART::surv.pre.bart,
+        times = self$model$times,
+        delta = self$model$delta,
+        x.train = self$model$x.train,
+        x.test = x.test,
+        .args = pars_pre
+      )
+
+      # subset parameters to forward to `predict()`, so that a set `openmp` takes effect
+      pars_pred = pars[names(pars) %in% c("mc.cores", "openmp", "nice")]
+
+      pred_fun = function() {
+        invoke(
+          predict,
+          self$model$model,
+          newdata = trans_data$tx.test,
+          .args = pars_pred
+        )
+      }
+
+      # don't print C++ generated info during prediction (unset `quiet` => declared default TRUE)
+      if (is.null(pars$quiet) || isTRUE(pars$quiet)) {
+        utils::capture.output({
+          pred = pred_fun()
+        })
+      } else {
+        pred = pred_fun()
+      }
+
+      # Number of test observations
+      N = task$nrow
+      # Number of unique times
+      K = pred$K
+      times = pred$times
+      # Number of posterior draws
+      M = nrow(pred$surv.test)
+
+      # Convert full posterior survival matrix to 3D survival array
+      # See page 34-35 in Sparapani (2021) for more details
+      surv_array = aperm(
+        array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)),
+        c(3, 2, 1)
+      )
+
+      # distr => 3d survival array
+      # crank => expected mortality
+      mlr3proba::.surv_return(times = times, surv = surv_array,
+        which.curve = pars$which.curve)
+    }
+  )
+)
+
+.extralrns_dict$add("surv.bart", LearnerSurvLearnerSurvBART)
diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd
new file mode 100644
index 000000000..3c72ce758
--- /dev/null
+++ b/man/mlr_learners_surv.bart.Rd
@@ -0,0 +1,210 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/learner_BART_surv_bart.R
+\name{mlr_learners_surv.bart}
+\alias{mlr_learners_surv.bart}
+\alias{LearnerSurvLearnerSurvBART}
+\title{Survival Bayesian Additive Regression Trees Learner}
+\description{
+Fits a Bayesian Additive Regression Trees (BART) learner to right-censored
+survival data. Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}.
+}
+\details{
+Two types of prediction are returned for this learner:
+\enumerate{
+\item \code{distr}: a 3d survival array with observations as 1st dimension, time
+points as 2nd and the posterior draws as 3rd dimension.
+\item \code{crank}: the expected mortality using \link[mlr3proba:dot-surv_return]{mlr3proba::.surv_return}. The parameter
+\code{which.curve} decides which posterior draw (3rd dimension) will be used for the
+calculation of the expected mortality. Note that the median posterior is
+by default used for the calculation of survival measures that require a \code{distr}
+prediction, see more info on \link[mlr3proba:PredictionSurv]{PredictionSurv}.
+}
+}
+\section{Custom mlr3 defaults}{
+
+\itemize{
+\item \code{mc.cores} is initialized to 1 to avoid threading conflicts with \CRANpkg{future}.
+} +} + +\section{Custom mlr3 parameters}{ + +\itemize{ +\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is +initialized to \code{TRUE}. +\item \code{importance} allows to choose the type of importance. Default is \code{count}, +see documentation of method \verb{$importance()} for more details. +\item \code{which.curve} allows to choose which posterior draw will be used for the +calculation of the \code{crank} prediction. If between (0,1) it is taken as the +quantile of the curves otherwise if greater than 1 it is taken as the curve +index, can also be 'mean'. By default the \strong{median posterior} is used, +i.e. \code{which.curve} is 0.5. +} +} + +\section{Dictionary}{ + +This \link{Learner} can be instantiated via the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_learners} or with the associated sugar function \code{\link[=lrn]{lrn()}}: + +\if{html}{\out{