diff --git a/DESCRIPTION b/DESCRIPTION index f19e55a38..0c374b72e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -36,6 +36,7 @@ Suggests: aorsf (>= 0.0.5), actuar, apcluster, + BART (>= 2.9.4), C50, coin, CoxBoost, @@ -66,7 +67,7 @@ Suggests: mgcv, mlr3cluster, mlr3learners (>= 0.4.2), - mlr3proba, + mlr3proba (>= 0.5.3), mlr3pipelines, mvtnorm, nnet, diff --git a/NAMESPACE b/NAMESPACE index a5b0f79da..16ee3d9c3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -80,6 +80,7 @@ export(LearnerSurvGAMBoost) export(LearnerSurvGBM) export(LearnerSurvGLMBoost) export(LearnerSurvGlmnet) +export(LearnerSurvLearnerSurvBART) export(LearnerSurvLogisticHazard) export(LearnerSurvMBoost) export(LearnerSurvNelson) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R new file mode 100644 index 000000000..3c35f4c01 --- /dev/null +++ b/R/learner_BART_surv_bart.R @@ -0,0 +1,218 @@ +#' @title Survival Bayesian Additive Regression Trees Learner +#' @author bblodfon +#' @name mlr_learners_surv.bart +#' +#' @description +#' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored +#' survival data. Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. +#' +#' @details +#' Two types of prediction are returned for this learner: +#' 1. `distr`: a 3d survival array with observations as 1st dimension, time +#' points as 2nd and the posterior draws as 3rd dimension. +#' 2. `crank`: the expected mortality using [mlr3proba::.surv_return]. The parameter +#' `which.curve` decides which posterior draw (3rd dimension) will be used for the +#' calculation of the expected mortality. Note that the median posterior is +#' by default used for the calculation of survival measures that require a `distr` +#' prediction, see more info on [PredictionSurv][mlr3proba::PredictionSurv]. +#' +#' @section Custom mlr3 defaults: +#' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. +#' +#' @section Custom mlr3 parameters: +#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is +#' initialized to `TRUE`. +#' - `importance` allows to choose the type of importance. Default is `count`, +#' see documentation of method `$importance()` for more details. +#' - `which.curve` allows to choose which posterior draw will be used for the +#' calculation of the `crank` prediction. If between (0,1) it is taken as the +#' quantile of the curves otherwise if greater than 1 it is taken as the curve +#' index, can also be 'mean'. By default the **median posterior** is used, +#' i.e. `which.curve` is 0.5. +#' +#' @templateVar id surv.bart +#' @template learner +#' +#' @references +#' `r format_bib("sparapani2021nonparametric", "chipman2010bart")` +#' +#' @template seealso_learner +#' @template example +#' @export +LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", + inherit = mlr3proba::LearnerSurv, + public = list( + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + initialize = function() { + param_set = ps( + K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")), + events = p_uty(default = NULL, tags = c("train", "predict")), + ztimes = p_uty(default = NULL, tags = c("train", "predict")), + zdelta = p_uty(default = NULL, tags = c("train", "predict")), + sparse = p_lgl(default = FALSE, tags = "train"), + theta = p_dbl(default = 0, tags = "train"), + omega = p_dbl(default = 1, tags = "train"), + a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"), + b = p_dbl(default = 1L, tags = "train"), + augment = p_lgl(default = FALSE, tags = "train"), + rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), + usequants = p_lgl(default = FALSE, tags = "train"), + rm.const = p_lgl(default = TRUE, tags = "train"), + type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"), + ntype = p_int(lower = 1, upper = 3, tags = "train"), + k = p_dbl(default = 2.0, lower = 0, tags = "train"), + power = p_dbl(default = 2.0, lower = 0, tags = "train"), + base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"), + offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), + ntree = p_int(default = 50L, lower = 1L, tags = "train"), + numcut = p_int(default = 100L, lower = 1L, tags = "train"), + ndpost = p_int(default = 1000L, lower = 1L, tags = "train"), + nskip = p_int(default = 250L, lower = 0L, tags = "train"), + keepevery = p_int(default = 10L, lower = 1L, tags = "train"), + printevery = p_int(default = 100L, lower = 1L, tags = "train"), + seed = p_int(default = 99L, tags = "train"), + mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")), + nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")), + openmp = p_lgl(default = TRUE, tags = "predict"), + quiet = p_lgl(default = TRUE, tags = "predict"), + importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"), + which.curve = p_dbl(lower = 0L, special_vals = list("mean"), tags = "predict") + ) + + # custom defaults + param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count", + which.curve = 0.5) # 0.5 quantile => median posterior + + super$initialize( + id = "surv.bart", + packages = "BART", + feature_types = c("logical", "integer", "numeric"), + predict_types = c("crank", "distr"), + param_set = param_set, + properties = c("importance", "missings"), + man = "mlr3extralearners::mlr_learners_surv.bart", + label = "Bayesian Additive Regression Trees" + ) + }, + + #' @description + #' Two types of importance scores are supported based on the value + #' of the parameter `importance`: + #' 1. `prob`: The mean selection probability of each feature in the trees, + #' extracted from the slot `varprob.mean`. + #' If `sparse = FALSE` (default), this is a fixed constant. + #' Recommended to use this option when `sparse = TRUE`. + #' 2. `count`: The mean observed count of each feature in the trees (average + #' number of times the feature was used in a tree decision rule across all + #' posterior draws), extracted from the slot `varcount.mean`. + #' This is the default importance scores. + #' + #' In both cases, higher values signify more important variables. + #' + #' @return Named `numeric()`. + importance = function() { + if (is.null(self$model$model)) { + stopf("No model stored") + } + + pars = self$param_set$get_values(tags = "train") + + if (pars$importance == "prob") { + sort(self$model$model$varprob.mean[-1], decreasing = TRUE) + } else { + sort(self$model$model$varcount.mean[-1], decreasing = TRUE) + } + } + ), + + private = list( + .train = function(task) { + pars = self$param_set$get_values(tags = "train") + pars$importance = NULL # not used in the train function + + x.train = as.data.frame(task$data(cols = task$feature_names)) # nolint + truth = task$truth() + times = truth[, 1] + delta = truth[, 2] # delta => status + + list( + model = invoke( + BART::mc.surv.bart, + x.train = x.train, + times = times, + delta = delta, + .args = pars + ), + # need these for predict + x.train = x.train, + times = times, + delta = delta + ) + }, + + .predict = function(task) { + # get parameters with tag "predict" + pars = self$param_set$get_values(tags = "predict") + + # get newdata and ensure same ordering in train and predict + x.test = as.data.frame(ordered_features(task, self)) # nolint + + # subset parameters to use in `surv.pre.bart` + pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")] + + # transform data to be suitable for BART survival analysis (needs train data) + trans_data = invoke( + BART::surv.pre.bart, + times = self$model$times, + delta = self$model$delta, + x.train = self$model$x.train, + x.test = x.test, + .args = pars_pre + ) + + # subset parameters to use in `predict` + pars_pred = pars[names(pars) %in% c("mc.cores", "nice")] + + pred_fun = function() { + invoke( + predict, + self$model$model, + newdata = trans_data$tx.test, + .args = pars_pred + ) + } + + # don't print C++ generated info during prediction + if (pars$quiet) { + utils::capture.output({ + pred = pred_fun() + }) + } else { + pred = pred_fun() + } + + # Number of test observations + N = task$nrow + # Number of unique times + K = pred$K + times = pred$times + # Number of posterior draws + M = nrow(pred$surv.test) + + # Convert full posterior survival matrix to 3D survival array + # See page 34-35 in Sparapani (2021) for more details + surv_array = aperm( + array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)), + c(3, 2, 1) + ) + + # distr => 3d survival array + # crank => expected mortality + mlr3proba::.surv_return(times = times, surv = surv_array, + which.curve = pars$which.curve) + } + ) +) + +.extralrns_dict$add("surv.bart", LearnerSurvLearnerSurvBART) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd new file mode 100644 index 000000000..3c72ce758 --- /dev/null +++ b/man/mlr_learners_surv.bart.Rd @@ -0,0 +1,210 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/learner_BART_surv_bart.R +\name{mlr_learners_surv.bart} +\alias{mlr_learners_surv.bart} +\alias{LearnerSurvLearnerSurvBART} +\title{Survival Bayesian Additive Regression Trees Learner} +\description{ +Fits a Bayesian Additive Regression Trees (BART) learner to right-censored +survival data. Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. +} +\details{ +Two types of prediction are returned for this learner: +\enumerate{ +\item \code{distr}: a 3d survival array with observations as 1st dimension, time +points as 2nd and the posterior draws as 3rd dimension. +\item \code{crank}: the expected mortality using \link[mlr3proba:dot-surv_return]{mlr3proba::.surv_return}. The parameter +\code{which.curve} decides which posterior draw (3rd dimension) will be used for the +calculation of the expected mortality. Note that the median posterior is +by default used for the calculation of survival measures that require a \code{distr} +prediction, see more info on \link[mlr3proba:PredictionSurv]{PredictionSurv}. +} +} +\section{Custom mlr3 defaults}{ + +\itemize{ +\item \code{mc.cores} is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. +} +} + +\section{Custom mlr3 parameters}{ + +\itemize{ +\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is +initialized to \code{TRUE}. +\item \code{importance} allows to choose the type of importance. Default is \code{count}, +see documentation of method \verb{$importance()} for more details. +\item \code{which.curve} allows to choose which posterior draw will be used for the +calculation of the \code{crank} prediction. If between (0,1) it is taken as the +quantile of the curves otherwise if greater than 1 it is taken as the curve +index, can also be 'mean'. By default the \strong{median posterior} is used, +i.e. \code{which.curve} is 0.5. +} +} + +\section{Dictionary}{ + +This \link{Learner} can be instantiated via the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_learners} or with the associated sugar function \code{\link[=lrn]{lrn()}}: + +\if{html}{\out{
}}\preformatted{mlr_learners$get("surv.bart") +lrn("surv.bart") +}\if{html}{\out{
}} +} + +\section{Meta Information}{ + +\itemize{ +\item Task type: \dQuote{surv} +\item Predict Types: \dQuote{crank}, \dQuote{distr} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric} +\item Required Packages: \CRANpkg{mlr3}, \CRANpkg{mlr3proba}, \CRANpkg{BART} +} +} + +\section{Parameters}{ +\tabular{lllll}{ + Id \tab Type \tab Default \tab Levels \tab Range \cr + K \tab numeric \tab NULL \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + events \tab untyped \tab \tab \tab - \cr + ztimes \tab untyped \tab \tab \tab - \cr + zdelta \tab untyped \tab \tab \tab - \cr + sparse \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + theta \tab numeric \tab 0 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + omega \tab numeric \tab 1 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + a \tab numeric \tab 0.5 \tab \tab \eqn{[0.5, 1]}{[0.5, 1]} \cr + b \tab numeric \tab 1 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + augment \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + rho \tab numeric \tab NULL \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + usequants \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + rm.const \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr + type \tab character \tab pbart \tab pbart, lbart \tab - \cr + ntype \tab integer \tab - \tab \tab \eqn{[1, 3]}{[1, 3]} \cr + k \tab numeric \tab 2 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + power \tab numeric \tab 2 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + base \tab numeric \tab 0.95 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + offset \tab numeric \tab NULL \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + ntree \tab integer \tab 50 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + numcut \tab integer \tab 100 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + ndpost \tab integer \tab 1000 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + nskip \tab integer \tab 250 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + keepevery \tab integer \tab 10 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + printevery \tab integer \tab 100 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + seed \tab integer \tab 99 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + mc.cores \tab integer \tab 2 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + nice \tab integer \tab 19 \tab \tab \eqn{[0, 19]}{[0, 19]} \cr + openmp \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr + quiet \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr + importance \tab character \tab count \tab count, prob \tab - \cr + which.curve \tab numeric \tab - \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr +} +} + +\examples{ +learner = mlr3::lrn("surv.bart") +print(learner) + +# available parameters: +learner$param_set$ids() +} +\references{ +Sparapani, Rodney, Spanbauer, Charles, McCulloch, Robert (2021). +\dQuote{Nonparametric machine learning and efficient computation with bayesian additive regression trees: the BART R package.} +\emph{Journal of Statistical Software}, \bold{97}, 1--66. + +Chipman, A H, George, I E, McCulloch, E R (2010). +\dQuote{BART: Bayesian additive regression trees.} +\emph{The Annals of Applied Statistics}, \bold{4}(1), 266--298. +} +\seealso{ +\itemize{ +\item \link[mlr3misc:Dictionary]{Dictionary} of \link[mlr3:Learner]{Learners}: \link[mlr3:mlr_learners]{mlr3::mlr_learners}. +\item \code{as.data.table(mlr_learners)} for a table of available \link[=Learner]{Learners} in the running session (depending on the loaded packages). +\item Chapter in the \href{https://mlr3book.mlr-org.com/}{mlr3book}: \url{https://mlr3book.mlr-org.com/basics.html#learners} +\item \CRANpkg{mlr3learners} for a selection of recommended learners. +\item \CRANpkg{mlr3cluster} for unsupervised clustering learners. +\item \CRANpkg{mlr3pipelines} to combine learners with pre- and postprocessing steps. +\item \CRANpkg{mlr3tuning} for tuning of hyperparameters, \CRANpkg{mlr3tuningspaces} for established default tuning spaces. +} +} +\author{ +bblodfon +} +\section{Super classes}{ +\code{\link[mlr3:Learner]{mlr3::Learner}} -> \code{\link[mlr3proba:LearnerSurv]{mlr3proba::LearnerSurv}} -> \code{LearnerSurvLearnerSurvBART} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-LearnerSurvLearnerSurvBART-new}{\code{LearnerSurvLearnerSurvBART$new()}} +\item \href{#method-LearnerSurvLearnerSurvBART-importance}{\code{LearnerSurvLearnerSurvBART$importance()}} +\item \href{#method-LearnerSurvLearnerSurvBART-clone}{\code{LearnerSurvLearnerSurvBART$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerSurvLearnerSurvBART-new}{}}} +\subsection{Method \code{new()}}{ +Creates a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerSurvLearnerSurvBART$new()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerSurvLearnerSurvBART-importance}{}}} +\subsection{Method \code{importance()}}{ +Two types of importance scores are supported based on the value +of the parameter \code{importance}: +\enumerate{ +\item \code{prob}: The mean selection probability of each feature in the trees, +extracted from the slot \code{varprob.mean}. +If \code{sparse = FALSE} (default), this is a fixed constant. +Recommended to use this option when \code{sparse = TRUE}. +\item \code{count}: The mean observed count of each feature in the trees (average +number of times the feature was used in a tree decision rule across all +posterior draws), extracted from the slot \code{varcount.mean}. +This is the default importance scores. +} + +In both cases, higher values signify more important variables. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerSurvLearnerSurvBART$importance()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Named \code{numeric()}. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerSurvLearnerSurvBART-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerSurvLearnerSurvBART$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/tests/testthat/test_BART_surv_bart.R b/tests/testthat/test_BART_surv_bart.R new file mode 100644 index 000000000..60c25a568 --- /dev/null +++ b/tests/testthat/test_BART_surv_bart.R @@ -0,0 +1,6 @@ +test_that("autotest", { + learner = lrn("surv.bart", nskip = 1, ndpost = 1, keepevery = 1) + expect_learner(learner) + result = run_autotest(learner) + expect_true(result, info = result$error) +}) diff --git a/tests/testthat/test_paramtest_BART_surv_bart.R b/tests/testthat/test_paramtest_BART_surv_bart.R new file mode 100644 index 000000000..d55c865eb --- /dev/null +++ b/tests/testthat/test_paramtest_BART_surv_bart.R @@ -0,0 +1,38 @@ +test_that("paramtest surv.bart train", { + learner = lrn("surv.bart") + fun_list = list(BART::mc.surv.bart) + exclude = c( + "x.train", # handled internally + "y.train", # handled internally + "times", # handled internally + "delta", # handled internally + "x.test", # not used + "xinfo", # not used + "tau.num", # not used, automatically calculated from `type` and `ntype` + "id", # only for `surv.bart` + "importance" # added to choose the type of importance + ) + + paramtest = run_paramtest(learner, fun_list, exclude, tag = "train") + expect_paramtest(paramtest) +}) + +test_that("paramset surv.bart predict", { + learner = lrn("surv.bart") + fun_list = list(BART:::predict.survbart, BART::surv.pre.bart) # nolint + exclude = c( + "object", # handled internally + "newdata", # handled internally + "openmp", # handled internally + "times", # handled internally, used in `surv.pre.bart` + "delta", # handled internally, used in `surv.pre.bart` + "x.train", # handled internally, used in `surv.pre.bart` + "x.test", # handled internally, used in `surv.pre.bart` + "nice", # handled internally + "quiet", # added to suppress print messages + "which.curve" # added to choose 3rd dimension (posterior draw) for crank calculation + ) + + paramtest = run_paramtest(learner, fun_list, exclude, tag = "predict") + expect_paramtest(paramtest) +}) diff --git a/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R b/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R index 5212d4499..005560f9a 100644 --- a/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R +++ b/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R @@ -9,7 +9,21 @@ test_that("paramtest surv.dnnsurv train", { "time_variable", # handled internally "status_variable", # handled internally "x", # unused - "y" # unused + "y", # unused + "rho", # handled internally + "global_clipnorm", # handled internally + "use_ema", # handled internally + "ema_momentum", # handled internally + "ema_overwrite_frequency", # handled internally + "jit_compile", # handled internally + "initial_accumultator_value", # handled internally + "amsgrad", # handled internally + "lr_power", # handled internally + "l1_regularization_strength", # handled internally + "l2_regularization_strength", # handled internally + "l2_shrinkage_regularization_strength", # handled internally + "beta", # handled internally + "centered" # handled internally ) paramtest = run_paramtest(learner, fun_list, exclude, tag = "train") diff --git a/vignettes/extending.Rmd b/vignettes/extending.Rmd index c3fb516ce..52ea7dc97 100644 --- a/vignettes/extending.Rmd +++ b/vignettes/extending.Rmd @@ -316,7 +316,7 @@ test_that("autotest", { learner = LearnerRegrRpartSimple$new() # basic learner properties - expect_learner(learner, check_man = FALSE) + expect_learner(learner) # you can skip tests using the `exclude` argument result = run_autotest(learner)