From bd5f64240e66660fba2ffa4c6c4db99f9ec52177 Mon Sep 17 00:00:00 2001 From: john Date: Tue, 15 Aug 2023 16:04:07 +0200 Subject: [PATCH 01/47] update doc links --- man/mlr_learners_classif.C50.Rd | 2 +- man/mlr_learners_classif.abess.Rd | 2 +- man/mlr_learners_dens.mixed.Rd | 2 +- man/mlr_learners_regr.abess.Rd | 2 +- man/mlr_learners_regr.cubist.Rd | 2 +- man/mlr_learners_surv.gamboost.Rd | 2 +- man/mlr_learners_surv.glmboost.Rd | 2 +- man/mlr_learners_surv.mboost.Rd | 4 ++-- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/man/mlr_learners_classif.C50.Rd b/man/mlr_learners_classif.C50.Rd index 42c0dc525..043e04e45 100644 --- a/man/mlr_learners_classif.C50.Rd +++ b/man/mlr_learners_classif.C50.Rd @@ -6,7 +6,7 @@ \title{Classification C5.0 Learner} \description{ Decision Tree Algorithm. -Calls \code{\link[C50:C5.0.formula]{C50::C5.0.formula()}} from \CRANpkg{C50}. +Calls \code{\link[C50:C5.0]{C50::C5.0.formula()}} from \CRANpkg{C50}. } \section{Dictionary}{ diff --git a/man/mlr_learners_classif.abess.Rd b/man/mlr_learners_classif.abess.Rd index 793ddc370..411d25103 100644 --- a/man/mlr_learners_classif.abess.Rd +++ b/man/mlr_learners_classif.abess.Rd @@ -134,7 +134,7 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-LearnerClassifAbess-selected_features}{}}} \subsection{Method \code{selected_features()}}{ -Extract the name of selected features from the model by \code{\link[abess:extract]{abess::extract()}}. +Extract the name of selected features from the model by \code{\link[abess:extract.abess]{abess::extract()}}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{LearnerClassifAbess$selected_features()}\if{html}{\out{
}} } diff --git a/man/mlr_learners_dens.mixed.Rd b/man/mlr_learners_dens.mixed.Rd index 3d1d65a7d..a1e1db3d5 100644 --- a/man/mlr_learners_dens.mixed.Rd +++ b/man/mlr_learners_dens.mixed.Rd @@ -6,7 +6,7 @@ \title{Density Mixed Data Kernel Learner} \description{ Density estimator for discrete and continuous variables. -Calls \code{\link[np:npudens]{np::npudens()}} from \CRANpkg{np}. +Calls \code{\link[np:np.density]{np::npudens()}} from \CRANpkg{np}. } \section{Dictionary}{ diff --git a/man/mlr_learners_regr.abess.Rd b/man/mlr_learners_regr.abess.Rd index 188c1dae9..5bab35838 100644 --- a/man/mlr_learners_regr.abess.Rd +++ b/man/mlr_learners_regr.abess.Rd @@ -126,7 +126,7 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-LearnerRegrAbess-selected_features}{}}} \subsection{Method \code{selected_features()}}{ -Extract the name of selected features from the model by \code{\link[abess:extract]{abess::extract()}}. +Extract the name of selected features from the model by \code{\link[abess:extract.abess]{abess::extract()}}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{LearnerRegrAbess$selected_features()}\if{html}{\out{
}} } diff --git a/man/mlr_learners_regr.cubist.Rd b/man/mlr_learners_regr.cubist.Rd index c25c3a794..3aa475fd1 100644 --- a/man/mlr_learners_regr.cubist.Rd +++ b/man/mlr_learners_regr.cubist.Rd @@ -7,7 +7,7 @@ \description{ Rule-based model that is an extension of Quinlan's M5 model tree. Each tree contains linear regression models at the terminal leaves. -Calls \code{\link[Cubist:cubist]{Cubist::cubist()}} from \CRANpkg{Cubist}. +Calls \code{\link[Cubist:cubist.default]{Cubist::cubist()}} from \CRANpkg{Cubist}. } \section{Dictionary}{ diff --git a/man/mlr_learners_surv.gamboost.Rd b/man/mlr_learners_surv.gamboost.Rd index 7ee93b37d..9fb239c2f 100644 --- a/man/mlr_learners_surv.gamboost.Rd +++ b/man/mlr_learners_surv.gamboost.Rd @@ -133,7 +133,7 @@ Named \code{numeric()}. \if{latex}{\out{\hypertarget{method-LearnerSurvGAMBoost-selected_features}{}}} \subsection{Method \code{selected_features()}}{ Selected features are extracted with the function -\code{\link[mboost:variable.names.mboost]{mboost::variable.names.mboost()}}, with +\code{\link[mboost:methods]{mboost::variable.names.mboost()}}, with \code{used.only = TRUE}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{LearnerSurvGAMBoost$selected_features()}\if{html}{\out{
}} diff --git a/man/mlr_learners_surv.glmboost.Rd b/man/mlr_learners_surv.glmboost.Rd index eef4c6703..e07654fed 100644 --- a/man/mlr_learners_surv.glmboost.Rd +++ b/man/mlr_learners_surv.glmboost.Rd @@ -116,7 +116,7 @@ matrix and original names can't be recovered. description Selected features are extracted with the function -\code{\link[mboost:variable.names.mboost]{mboost::variable.names.mboost()}}, with +\code{\link[mboost:methods]{mboost::variable.names.mboost()}}, with \code{used.only = TRUE}. return \code{character()}. \subsection{Usage}{ diff --git a/man/mlr_learners_surv.mboost.Rd b/man/mlr_learners_surv.mboost.Rd index e5f1e6410..fc0b9c1ac 100644 --- a/man/mlr_learners_surv.mboost.Rd +++ b/man/mlr_learners_surv.mboost.Rd @@ -6,7 +6,7 @@ \title{Boosted Generalized Additive Survival Learner} \description{ Model-based boosting for survival analysis. -Calls \code{\link[mboost:mboost]{mboost::mboost()}} from \CRANpkg{mboost}. +Calls \code{\link[mboost:gamboost]{mboost::mboost()}} from \CRANpkg{mboost}. } \details{ \code{distr} prediction made by \code{\link[mboost:survFit]{mboost::survFit()}}. @@ -131,7 +131,7 @@ Named \code{numeric()}. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-LearnerSurvMBoost-selected_features}{}}} \subsection{Method \code{selected_features()}}{ -Selected features are extracted with the function \code{\link[mboost:variable.names.mboost]{mboost::variable.names.mboost()}}, with +Selected features are extracted with the function \code{\link[mboost:methods]{mboost::variable.names.mboost()}}, with \code{used.only = TRUE}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{LearnerSurvMBoost$selected_features()}\if{html}{\out{
}} From 8fa59a2dd70a54a26cbce06d46342c12282f471d Mon Sep 17 00:00:00 2001 From: john Date: Thu, 17 Aug 2023 13:18:34 +0200 Subject: [PATCH 02/47] add BART survival learner --- NAMESPACE | 1 + R/learner_BART_surv_bart.R | 292 ++++++++++++++++++ man/mlr_learners_surv.bart.Rd | 280 +++++++++++++++++ tests/testthat/test_BART_surv_bart.R | 8 + .../testthat/test_paramtest_BART_surv_bart.R | 36 +++ 5 files changed, 617 insertions(+) create mode 100644 R/learner_BART_surv_bart.R create mode 100644 man/mlr_learners_surv.bart.Rd create mode 100644 tests/testthat/test_BART_surv_bart.R create mode 100644 tests/testthat/test_paramtest_BART_surv_bart.R diff --git a/NAMESPACE b/NAMESPACE index 9662fce9a..e3b3a1513 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -80,6 +80,7 @@ export(LearnerSurvGAMBoost) export(LearnerSurvGBM) export(LearnerSurvGLMBoost) export(LearnerSurvGlmnet) +export(LearnerSurvLearnerSurvBART) export(LearnerSurvLogisticHazard) export(LearnerSurvMBoost) export(LearnerSurvNelson) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R new file mode 100644 index 000000000..429d4d536 --- /dev/null +++ b/R/learner_BART_surv_bart.R @@ -0,0 +1,292 @@ +#' @title Survival Bayesian Additive Regression Trees Learner +#' @author bblodfon +#' @name mlr_learners_surv.bart +#' +#' @description +#' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored +#' survival data. +#' For prediction, we return the mean posterior estimates of the survival +#' function and the corresponding `crank` (expected mortality) using +#' [mlr3proba::.surv_return]. +#' The full posterior estimates are currently stored in the +#' `learner$state$surv_test` slot, along with the number of test observations +#' `N`, number of unique times in the train set `K` and number of posterior +#' draws `M`. +#' See example for more details. +#' +#' Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. +#' Based on BART version 2.9.4 (2023-03-25). 
+#' +#' @section Custom mlr3 defaults: +#' - `mc.cores` (in general use as many as possible if no issues arise): +#' - Actual default: 2 +#' - Adjusted value: 1 +#' - Reason for change: May conflict with parallelization via \CRANpkg{future}. +#' +#' @section Initial parameter values: +#' - `quiet` +#' - Default is `TRUE`, to suppress messages generated by the wrapped C++ code +#' during prediction. +#' +#' @templateVar id surv.bart +#' @template learner +#' +#' @references +#' `r format_bib("sparapani2021nonparametric", "chipman2010bart")` +#' +#' @template seealso_learner +#' @examples +#' library(mlr3proba) +#' library(dplyr) +#' library(ggplot2) +#' +#' learner = lrn("surv.bart", nskip = 10, ndpost = 20, keepevery = 2) +#' task = tsk("lung") +#' task$missings() # has missing values +#' +#' # split to train and test sets +#' set.seed(42) +#' part = partition(task) +#' +#' # Train +#' learner$train(task, row_ids = part$train) +#' +#' # Importance: average number of times a feature has been used in the trees +#' learner$importance() +#' +#' # Test +#' p = learner$predict(task, row_ids = part$test) +#' p$score() # C-index +#' +#' # Mean survival probabilities for the first 3 patients at given time points +#' p$distr$survival(times = c(1,50,150))[,1:3] +#' +#' # number of posterior draws +#' M = learner$state$M +#' stopifnot(M == 20) +#' # number of test observations +#' N = learner$state$N +#' stopifnot(N == length(part$test)) +#' # number of unique time points in the train set +#' K = learner$state$K +#' stopifnot(K == length(task$unique_times(rows = part$train))) +#' # the actual times are also available in the `$model` slot: +#' head(learner$model$times) +#' +#' # Full posterior prediction matrix +#' surv_test = learner$state$surv_test +#' stopifnot(all(dim(surv_test) == c(M, K * N))) +#' +#' # Posterior survival function estimates for the 1st test patient for all +#' # time points (from the train set) - see Sparapani (2021), pages 34-35 +#' post_surv = 
surv_test[, 1:K] +#' +#' # For every time point, get the median survival estimate as well as +#' # the lower and upper bounds of the 95% quantile credible interval +#' surv_data = post_surv %>% +#' as.data.frame() %>% +#' `colnames<-` (learner$model$times) %>% +#' summarise(across(everything(), list( +#' median = ~ median(.), +#' low_qi = ~ quantile(., 0.025), +#' high_qi = ~ quantile(., 0.975) +#' ))) %>% +#' tidyr::pivot_longer( +#' cols = everything(), +#' names_to = c("times", ".value"), +#' names_pattern = "(^[^_]+)_(.*)" # everything until the first underscore +#' ) %>% +#' mutate(times = as.numeric(times)) +#' surv_data +#' +#' # Draw a survival curve for the first patient in the test set with +#' # uncertainty quantified +#' surv_data %>% +#' ggplot2::ggplot(aes(x = times, y = median)) + +#' geom_step(col = 'black') + +#' xlab('Time (Days)') + +#' ylab('Survival Probability') + +#' geom_ribbon(aes(ymin = low_qi, ymax = high_qi), alpha = 0.3) + +#' theme_bw() +#' @export +delayedAssign( + "LearnerSurvLearnerSurvBART", R6Class("LearnerSurvLearnerSurvBART", + inherit = LearnerSurv, + public = list( + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. 
+ initialize = function() { + param_set = ps( + K = p_uty(default = NULL, tags = c("train", "predict")), + events = p_uty(default = NULL, tags = c("train", "predict")), + ztimes = p_uty(default = NULL, tags = c("train", "predict")), + zdelta = p_uty(default = NULL, tags = c("train", "predict")), + sparse = p_lgl(default = FALSE, tags = "train"), + theta = p_dbl(default = 0, tags = "train"), + omega = p_dbl(default = 1, tags = "train"), + a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"), + b = p_dbl(default = 1L, tags = "train"), + augment = p_lgl(default = FALSE, tags = "train"), + rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), + usequants = p_lgl(default = FALSE, tags = "train"), + rm.const = p_lgl(default = TRUE, tags = "train"), + type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"), + ntype = p_int(lower = 1, upper = 3, tags = "train"), + k = p_dbl(default = 2.0, lower = 0, tags = "train"), + power = p_dbl(default = 2.0, lower = 0, tags = "train"), + base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"), + offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), + ntree = p_int(default = 50L, lower = 1L, tags = "train"), + numcut = p_int(default = 100L, lower = 1L, tags = "train"), + ndpost = p_int(default = 1000L, lower = 1L, tags = "train"), + nskip = p_int(default = 250L, lower = 0L, tags = "train"), + keepevery = p_int(default = 10L, lower = 1L, tags = "train"), + printevery = p_int(default = 100L, lower = 1L, tags = "train"), + seed = p_int(default = 99L, tags = "train"), + mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")), + nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")), + openmp = p_lgl(default = TRUE, tags = "predict"), + quiet = p_lgl(default = TRUE, tags = "predict") + ) + + # custom defaults + param_set$values = list(mc.cores = 1, quiet = TRUE) + + super$initialize( + id = "surv.bart", + packages 
= "BART", + feature_types = c("logical", "integer", "numeric", "factor"), + predict_types = c("crank", "distr"), + param_set = param_set, + properties = c("importance", "missings"), + man = "mlr3extralearners::mlr_learners_surv.bart", + label = "Bayesian Additive Regression Trees" + ) + }, + + #' @description + #' Two types of importance scores are supported: + #' 1. The mean selection probability of each feature in the trees, + #' extracted from the slot `varprob.mean`. + #' If `sparse = FALSE` (default), this is a fixed constant. + #' Recommended to use this option when `sparse = TRUE`. + #' 2. The observed count of each feature in the trees, extracted from the + #' slot `varcount.mean`. + #' This is the default importance scores. + #' + #' In both cases, higher values signify more important variables. + #' + #' @param type Can be either `count` or `prob`. + #' @return Named `numeric()`. + importance = function(type = 'count') { + if (!type %in% c('prob', 'count')) stopf("type can be only 'prob' or 'count'") + + if (is.null(self$model)) { + stopf("No model stored") + } + + if (type == 'prob') { + sort(self$model$varprob.mean[-1], decreasing = T) + } else { + sort(self$model$varcount.mean[-1], decreasing = T) + } + } + ), + + private = list( + .train = function(task) { + pars = self$param_set$get_values(tags = "train") + + x_train = as.data.frame(task$data(cols = task$feature_names)) + times = task$truth()[,1] + delta = task$truth()[,2] # delta => status + + # need these for predict + self$state$train_data = list(x_train = x_train, times = times, delta = delta) + + invoke( + BART::mc.surv.bart, + x.train = x_train, + times = times, + delta = delta, + .args = pars + ) + }, + + .predict = function(task) { + # get parameters with tag "predict" + pars = self$param_set$get_values(tags = "predict") + + # get newdata and ensure same ordering in train and predict + x_test = as.data.frame(ordered_features(task, self)) + + # transform data to be suitable for BART survival 
analysis (needs train data) + train_data = self$state$train_data + + # subset parameters to use in `surv.pre.bart` + pars_pre = pars[names(pars) %in% c('K', 'events', 'ztimes', 'zdelta')] + + trans_data = invoke( + BART::surv.pre.bart, + times = train_data$times, + delta = train_data$delta, + x.train = train_data$x_train, + x.test = x_test, + .args = pars_pre + ) + + # subset parameters to use in `predict` + pars_pred = pars[names(pars) %in% c('mc.cores', 'nice')] + + pred_fun = function() { + invoke( + predict, + self$model, + newdata = trans_data$tx.test, + .args = pars_pred + ) + } + + # don't print C++ generated info during prediction + if (pars$quiet) { + capture.output({ + pred = pred_fun() + }) + } else { + pred = pred_fun() + } + + # Build survival matrix using the mean posterior estimates of the survival + # function, see page 34-35 in Sparapani (2021) for more details + + # Number of test observations + N = task$nrow + self$state$N = N + # Number of unique times + K = pred$K + self$state$K = K + # Number of posterior draws + self$state$M = nrow(pred$surv.test) + + # save the full posterior survival matrix + self$state$surv_test = pred$surv.test + + # create mean posterior survival matrix + surv = matrix(nrow = N, ncol = K) # obs x times + # check: output is flattened (column means) + stopifnot(length(pred$surv.test.mean) == N * K) + + for (i in 1:N) { + # every K-size vector contains the mean survival estimates of the + # i-th test observation + indxs = ((i-1)*K + 1):(i*K) + surv[i,] = pred$surv.test.mean[indxs] + } + + mlr3proba::.surv_return(times = pred$times, surv = surv) + } + ) + ) +) + +.extralrns_dict$add("surv.bart", LearnerSurvLearnerSurvBART) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd new file mode 100644 index 000000000..90354c4c8 --- /dev/null +++ b/man/mlr_learners_surv.bart.Rd @@ -0,0 +1,280 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/learner_BART_surv_bart.R 
+\name{mlr_learners_surv.bart} +\alias{mlr_learners_surv.bart} +\alias{LearnerSurvLearnerSurvBART} +\title{Survival Bayesian Additive Regression Trees Learner} +\description{ +Fits a Bayesian Additive Regression Trees (BART) learner to right-censored +survival data. +For prediction, we return the mean posterior estimates of the survival +function and the corresponding \code{crank} (expected mortality) using +\link[mlr3proba:dot-surv_return]{mlr3proba::.surv_return}. +The full posterior estimates are currently stored in the +\code{learner$state$surv_test} slot, along with the number of test observations +\code{N}, number of unique times in the train set \code{K} and number of posterior +draws \code{M}. +See example for more details. + +Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. +Based on BART version 2.9.4 (2023-03-25). +} +\section{Custom mlr3 defaults}{ + +\itemize{ +\item \code{mc.cores} (in general use as many as possible if no issues arise): +\itemize{ +\item Actual default: 2 +\item Adjusted value: 1 +\item Reason for change: May conflict with parallelization via \CRANpkg{future}. +} +} +} + +\section{Initial parameter values}{ + +\itemize{ +\item \code{quiet} +\itemize{ +\item Default is \code{TRUE}, to suppress messages generated by the wrapped C++ code +during prediction. +} +} +} + +\section{Dictionary}{ + +This \link{Learner} can be instantiated via the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_learners} or with the associated sugar function \code{\link[=lrn]{lrn()}}: + +\if{html}{\out{
}}\preformatted{mlr_learners$get("surv.bart") +lrn("surv.bart") +}\if{html}{\out{
}} +} + +\section{Meta Information}{ + +\itemize{ +\item Task type: \dQuote{surv} +\item Predict Types: \dQuote{crank}, \dQuote{distr} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{factor} +\item Required Packages: \CRANpkg{mlr3}, \CRANpkg{mlr3proba}, \CRANpkg{BART} +} +} + +\section{Parameters}{ +\tabular{lllll}{ + Id \tab Type \tab Default \tab Levels \tab Range \cr + K \tab untyped \tab \tab \tab - \cr + events \tab untyped \tab \tab \tab - \cr + ztimes \tab untyped \tab \tab \tab - \cr + zdelta \tab untyped \tab \tab \tab - \cr + sparse \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + theta \tab numeric \tab 0 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + omega \tab numeric \tab 1 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + a \tab numeric \tab 0.5 \tab \tab \eqn{[0.5, 1]}{[0.5, 1]} \cr + b \tab numeric \tab 1 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + augment \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + rho \tab numeric \tab NULL \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + usequants \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + rm.const \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr + type \tab character \tab pbart \tab pbart, lbart \tab - \cr + ntype \tab integer \tab - \tab \tab \eqn{[1, 3]}{[1, 3]} \cr + k \tab numeric \tab 2 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + power \tab numeric \tab 2 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + base \tab numeric \tab 0.95 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + offset \tab numeric \tab NULL \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + ntree \tab integer \tab 50 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + numcut \tab integer \tab 100 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + ndpost \tab integer \tab 1000 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + nskip \tab integer \tab 250 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + keepevery \tab integer \tab 10 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + printevery \tab 
integer \tab 100 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + seed \tab integer \tab 99 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + mc.cores \tab integer \tab 2 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + nice \tab integer \tab 19 \tab \tab \eqn{[0, 19]}{[0, 19]} \cr + openmp \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr + quiet \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr +} +} + +\examples{ +library(mlr3proba) +library(dplyr) +library(ggplot2) + +learner = lrn("surv.bart", nskip = 10, ndpost = 20, keepevery = 2) +task = tsk("lung") +task$missings() # has missing values + +# split to train and test sets +set.seed(42) +part = partition(task) + +# Train +learner$train(task, row_ids = part$train) + +# Importance: average number of times a feature has been used in the trees +learner$importance() + +# Test +p = learner$predict(task, row_ids = part$test) +p$score() # C-index + +# Mean survival probabilities for the first 3 patients at given time points +p$distr$survival(times = c(1,50,150))[,1:3] + +# number of posterior draws +M = learner$state$M +stopifnot(M == 20) +# number of test observations +N = learner$state$N +stopifnot(N == length(part$test)) +# number of unique time points in the train set +K = learner$state$K +stopifnot(K == length(task$unique_times(rows = part$train))) +# the actual times are also available in the `$model` slot: +head(learner$model$times) + +# Full posterior prediction matrix +surv_test = learner$state$surv_test +stopifnot(all(dim(surv_test) == c(M, K * N))) + +# Posterior survival function estimates for the 1st test patient for all +# time points (from the train set) - see Sparapani (2021), pages 34-35 +post_surv = surv_test[, 1:K] + +# For every time point, get the median survival estimate as well as +# the lower and upper bounds of the 95\% quantile credible interval +surv_data = post_surv \%>\% + as.data.frame() \%>\% + `colnames<-` (learner$model$times) \%>\% + summarise(across(everything(), list( + median = ~ 
median(.), + low_qi = ~ quantile(., 0.025), + high_qi = ~ quantile(., 0.975) + ))) \%>\% + tidyr::pivot_longer( + cols = everything(), + names_to = c("times", ".value"), + names_pattern = "(^[^_]+)_(.*)" # everything until the first underscore + ) \%>\% + mutate(times = as.numeric(times)) +surv_data + +# Draw a survival curve for the first patient in the test set with +# uncertainty quantified +surv_data \%>\% + ggplot2::ggplot(aes(x = times, y = median)) + + geom_step(col = 'black') + + xlab('Time (Days)') + + ylab('Survival Probability') + + geom_ribbon(aes(ymin = low_qi, ymax = high_qi), alpha = 0.3) + + theme_bw() +} +\references{ +Sparapani, Rodney, Spanbauer, Charles, McCulloch, Robert (2021). +\dQuote{Nonparametric machine learning and efficient computation with bayesian additive regression trees: the BART R package.} +\emph{Journal of Statistical Software}, \bold{97}, 1--66. + +Chipman, A H, George, I E, McCulloch, E R (2010). +\dQuote{BART: Bayesian additive regression trees.} +\emph{The Annals of Applied Statistics}, \bold{4}(1), 266--298. +} +\seealso{ +\itemize{ +\item \link[mlr3misc:Dictionary]{Dictionary} of \link[mlr3:Learner]{Learners}: \link[mlr3:mlr_learners]{mlr3::mlr_learners}. +\item \code{as.data.table(mlr_learners)} for a table of available \link[=Learner]{Learners} in the running session (depending on the loaded packages). +\item Chapter in the \href{https://mlr3book.mlr-org.com/}{mlr3book}: \url{https://mlr3book.mlr-org.com/basics.html#learners} +\item \CRANpkg{mlr3learners} for a selection of recommended learners. +\item \CRANpkg{mlr3cluster} for unsupervised clustering learners. +\item \CRANpkg{mlr3pipelines} to combine learners with pre- and postprocessing steps. +\item \CRANpkg{mlr3tuning} for tuning of hyperparameters, \CRANpkg{mlr3tuningspaces} for established default tuning spaces. 
+} +} +\author{ +bblodfon +} +\section{Super classes}{ +\code{\link[mlr3:Learner]{mlr3::Learner}} -> \code{\link[mlr3proba:LearnerSurv]{mlr3proba::LearnerSurv}} -> \code{LearnerSurvLearnerSurvBART} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-LearnerSurvLearnerSurvBART-new}{\code{LearnerSurvLearnerSurvBART$new()}} +\item \href{#method-LearnerSurvLearnerSurvBART-importance}{\code{LearnerSurvLearnerSurvBART$importance()}} +\item \href{#method-LearnerSurvLearnerSurvBART-clone}{\code{LearnerSurvLearnerSurvBART$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerSurvLearnerSurvBART-new}{}}} +\subsection{Method \code{new()}}{ +Creates a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerSurvLearnerSurvBART$new()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerSurvLearnerSurvBART-importance}{}}} +\subsection{Method \code{importance()}}{ +Two types of importance scores are supported: +\enumerate{ +\item The mean selection probability of each feature in the trees, +extracted from the slot \code{varprob.mean}. +If \code{sparse = FALSE} (default), this is a fixed constant. +Recommended to use this option when \code{sparse = TRUE}. +\item The observed count of each feature in the trees, extracted from the +slot \code{varcount.mean}. +This is the default importance scores. +} + +In both cases, higher values signify more important variables. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerSurvLearnerSurvBART$importance(type = "count")}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{type}}{Can be either \code{count} or \code{prob}.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Named \code{numeric()}. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerSurvLearnerSurvBART-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerSurvLearnerSurvBART$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/tests/testthat/test_BART_surv_bart.R b/tests/testthat/test_BART_surv_bart.R new file mode 100644 index 000000000..cc1ecec49 --- /dev/null +++ b/tests/testthat/test_BART_surv_bart.R @@ -0,0 +1,8 @@ +test_that("autotest", { + learner = lrn("surv.bart", nskip = 1, ndpost = 1, keepevery = 1) + expect_learner(learner) + # `feat_all` task has a factor which is split to two variables which give two + # different importance scores and not just once, so it fails (but it's okay) + result = run_autotest(learner, exclude = "feat_all") + expect_true(result, info = result$error) +}) diff --git a/tests/testthat/test_paramtest_BART_surv_bart.R b/tests/testthat/test_paramtest_BART_surv_bart.R new file mode 100644 index 000000000..0e94a6d6e --- /dev/null +++ b/tests/testthat/test_paramtest_BART_surv_bart.R @@ -0,0 +1,36 @@ +test_that("paramtest surv.bart train", { + learner = lrn("surv.bart") + fun_list = list(BART::mc.surv.bart) + exclude = c( + "x.train", # handled internally + "y.train", # handled internally + "times", # handled internally + "delta", # handled internally + "x.test", # not used + "xinfo", # not used + "tau.num", # not used, automatically calculated from `type` and `ntype` + "id" # only for `surv.bart` + ) + + paramtest = run_paramtest(learner, fun_list, exclude, tag = "train") + expect_paramtest(paramtest) +}) + +test_that("paramset surv.bart predict", { + learner = lrn("surv.bart") + fun_list = list(BART:::predict.survbart, BART::surv.pre.bart) # nolint + exclude = c( + "object", # handled internally + "newdata", # handled internally + "openmp", # handled internally + "times", # handled internally, used in `surv.pre.bart` + "delta", # handled internally, used in `surv.pre.bart` + "x.train", # handled internally, used in `surv.pre.bart` + "x.test", # handled internally, used in `surv.pre.bart` + "nice", # handled internally + "quiet" # added to suppress print messages + ) + + paramtest = run_paramtest(learner, fun_list, exclude, tag = 
"predict") + expect_paramtest(paramtest) +}) From 792b7684e8d10ea7c0e718ea1aa6fdd538a4efce Mon Sep 17 00:00:00 2001 From: john Date: Thu, 17 Aug 2023 13:22:15 +0200 Subject: [PATCH 03/47] fix dnnsurv parameter test --- ...test_paramtest_survivalmodels_surv_dnnsurv.R | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R b/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R index 5212d4499..8daaec6aa 100644 --- a/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R +++ b/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R @@ -9,7 +9,22 @@ test_that("paramtest surv.dnnsurv train", { "time_variable", # handled internally "status_variable", # handled internally "x", # unused - "y" # unused + "y", # unused + "schedule_decay", # not used + "rho", # handled internally + "global_clipnorm", # handled internally + "use_ema", # handled internally + "ema_momentum", # handled internally + "ema_overwrite_frequency", # handled internally + "jit_compile", # handled internally + "initial_accumultator_value", # handled internally + "amsgrad", # handled internally + "lr_power", # handled internally + "l1_regularization_strength", # handled internally + "l2_regularization_strength", # handled internally + "l2_shrinkage_regularization_strength", # handled internally + "beta", # handled internally + "centered" # handled internally ) paramtest = run_paramtest(learner, fun_list, exclude, tag = "train") From 33b9269e7ea7f3d942bfa22570de4094ec9f8406 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 17 Aug 2023 15:32:22 +0200 Subject: [PATCH 04/47] add BART to 'Suggests' --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index feb4914ef..650b85da0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -36,6 +36,7 @@ Suggests: aorsf (>= 0.0.5), actuar, apcluster, + BART (>= 2.9.4), C50, coin, CoxBoost, From 322f178d02307354e1d6fd6bd294f25a43ac94d3 Mon 
Sep 17 00:00:00 2001 From: john Date: Thu, 17 Aug 2023 16:49:05 +0200 Subject: [PATCH 05/47] various small fixes --- R/learner_BART_surv_bart.R | 9 +++++---- vignettes/extending.Rmd | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 429d4d536..ae9a6a3af 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -38,6 +38,7 @@ #' @examples #' library(mlr3proba) #' library(dplyr) +#' library(tidyr) #' library(ggplot2) #' #' learner = lrn("surv.bart", nskip = 10, ndpost = 20, keepevery = 2) @@ -91,7 +92,7 @@ #' low_qi = ~ quantile(., 0.025), #' high_qi = ~ quantile(., 0.975) #' ))) %>% -#' tidyr::pivot_longer( +#' pivot_longer( #' cols = everything(), #' names_to = c("times", ".value"), #' names_pattern = "(^[^_]+)_(.*)" # everything until the first underscore @@ -102,7 +103,7 @@ #' # Draw a survival curve for the first patient in the test set with #' # uncertainty quantified #' surv_data %>% -#' ggplot2::ggplot(aes(x = times, y = median)) + +#' ggplot(aes(x = times, y = median)) + #' geom_step(col = 'black') + #' xlab('Time (Days)') + #' ylab('Survival Probability') + @@ -111,7 +112,7 @@ #' @export delayedAssign( "LearnerSurvLearnerSurvBART", R6Class("LearnerSurvLearnerSurvBART", - inherit = LearnerSurv, + inherit = mlr3proba::LearnerSurv, public = list( #' @description #' Creates a new instance of this [R6][R6::R6Class] class. 
@@ -249,7 +250,7 @@ delayedAssign( # don't print C++ generated info during prediction if (pars$quiet) { - capture.output({ + utils::capture.output({ pred = pred_fun() }) } else { diff --git a/vignettes/extending.Rmd b/vignettes/extending.Rmd index c3fb516ce..52ea7dc97 100644 --- a/vignettes/extending.Rmd +++ b/vignettes/extending.Rmd @@ -316,7 +316,7 @@ test_that("autotest", { learner = LearnerRegrRpartSimple$new() # basic learner properties - expect_learner(learner, check_man = FALSE) + expect_learner(learner) # you can skip tests using the `exclude` argument result = run_autotest(learner) From 33bc17101f2eb947d502f97fdce12118847cd263 Mon Sep 17 00:00:00 2001 From: john Date: Fri, 18 Aug 2023 09:25:03 +0200 Subject: [PATCH 06/47] add more libraries to Suggests to run new BART example --- DESCRIPTION | 3 +++ 1 file changed, 3 insertions(+) diff --git a/DESCRIPTION b/DESCRIPTION index 650b85da0..fd79b4700 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -43,6 +43,7 @@ Suggests: Cubist, curl, dbarts, + dplyr, distr6, earth, flexsurv, @@ -50,6 +51,7 @@ Suggests: formattable, future, gbm, + ggplot2, glmnet, gss, jsonlite, @@ -99,6 +101,7 @@ Suggests: survivalsvm, tensorflow (>= 2.0.0), testthat, + tidyr, xgboost Remotes: alan-turing-institute/distr6, From 1871f0fd171f7184b35ba5b579c5583f5833b128 Mon Sep 17 00:00:00 2001 From: john Date: Fri, 18 Aug 2023 09:25:41 +0200 Subject: [PATCH 07/47] update doc --- man/mlr_learners_surv.bart.Rd | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index 90354c4c8..f5b7c7552 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -100,6 +100,7 @@ lrn("surv.bart") \examples{ library(mlr3proba) library(dplyr) +library(tidyr) library(ggplot2) learner = lrn("surv.bart", nskip = 10, ndpost = 20, keepevery = 2) @@ -153,7 +154,7 @@ surv_data = post_surv \%>\% low_qi = ~ quantile(., 0.025), high_qi = ~ quantile(., 0.975) ))) \%>\% - 
tidyr::pivot_longer( + pivot_longer( cols = everything(), names_to = c("times", ".value"), names_pattern = "(^[^_]+)_(.*)" # everything until the first underscore @@ -164,7 +165,7 @@ surv_data # Draw a survival curve for the first patient in the test set with # uncertainty quantified surv_data \%>\% - ggplot2::ggplot(aes(x = times, y = median)) + + ggplot(aes(x = times, y = median)) + geom_step(col = 'black') + xlab('Time (Days)') + ylab('Survival Probability') + From 989ff509e20740832fb9324416ede3eec26e7f65 Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:08:58 +0200 Subject: [PATCH 08/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index ae9a6a3af..2071de68b 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -15,7 +15,6 @@ #' See example for more details. #' #' Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. -#' Based on BART version 2.9.4 (2023-03-25). #' #' @section Custom mlr3 defaults: #' - `mc.cores` (in general use as many as possible if no issues arise): From 4d5ea9d2a544800bb8d2584e6a1bde03729dab32 Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:09:13 +0200 Subject: [PATCH 09/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 2071de68b..f3aa5059f 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -17,7 +17,7 @@ #' Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. #' #' @section Custom mlr3 defaults: -#' - `mc.cores` (in general use as many as possible if no issues arise): +#' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. 
#' - Actual default: 2 #' - Adjusted value: 1 #' - Reason for change: May conflict with parallelization via \CRANpkg{future}. From b82b1edafa374f74a8c4f5af0538832bb3c4a69f Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:09:20 +0200 Subject: [PATCH 10/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index f3aa5059f..63911ba86 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -18,7 +18,6 @@ #' #' @section Custom mlr3 defaults: #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. -#' - Actual default: 2 #' - Adjusted value: 1 #' - Reason for change: May conflict with parallelization via \CRANpkg{future}. #' From 3722e3cdc45824ef687c1c434f32e2e03f5dd32a Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:09:27 +0200 Subject: [PATCH 11/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 63911ba86..ea0b4b0a4 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -18,7 +18,6 @@ #' #' @section Custom mlr3 defaults: #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. -#' - Adjusted value: 1 #' - Reason for change: May conflict with parallelization via \CRANpkg{future}. 
#' #' @section Initial parameter values: From 387a41e5c8ca47d10cb120d430e8f2e569997b2f Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:09:49 +0200 Subject: [PATCH 12/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index ea0b4b0a4..fe9be0ad9 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -18,7 +18,6 @@ #' #' @section Custom mlr3 defaults: #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. -#' - Reason for change: May conflict with parallelization via \CRANpkg{future}. #' #' @section Initial parameter values: #' - `quiet` From f18d2d9e419889781959a2119b07ef0575a78747 Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:52:14 +0200 Subject: [PATCH 13/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index fe9be0ad9..a9f9e55d0 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -176,7 +176,7 @@ delayedAssign( #' @param type Can be either `count` or `prob`. #' @return Named `numeric()`. 
importance = function(type = 'count') { - if (!type %in% c('prob', 'count')) stopf("type can be only 'prob' or 'count'") +assert_choice(type, c("prob", "count")) if (is.null(self$model)) { stopf("No model stored") From c123fe15c5a4071f16839e757771edabe14aa18a Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:52:29 +0200 Subject: [PATCH 14/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index a9f9e55d0..b19295852 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -20,7 +20,7 @@ #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. #' #' @section Initial parameter values: -#' - `quiet` +#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is `TRUE` by default. #' - Default is `TRUE`, to suppress messages generated by the wrapped C++ code #' during prediction. #' From 38dbed2d74b6ba9145792c386cbf36e8be905d1b Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:52:39 +0200 Subject: [PATCH 15/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index b19295852..e3601f556 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -21,7 +21,6 @@ #' #' @section Initial parameter values: #' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is `TRUE` by default. -#' - Default is `TRUE`, to suppress messages generated by the wrapped C++ code #' during prediction. 
#' #' @templateVar id surv.bart From 6ab35703b284daee5be68f638f96dea334f88c8a Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:52:48 +0200 Subject: [PATCH 16/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index e3601f556..2cd1d170c 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -21,7 +21,6 @@ #' #' @section Initial parameter values: #' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is `TRUE` by default. -#' during prediction. #' #' @templateVar id surv.bart #' @template learner From ffd61c600d08fbd56a116fbc4d2c3b21e03a1a75 Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 21 Aug 2023 12:54:34 +0200 Subject: [PATCH 17/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 2cd1d170c..1ab876ac1 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -173,7 +173,7 @@ delayedAssign( #' #' @param type Can be either `count` or `prob`. #' @return Named `numeric()`. - importance = function(type = 'count') { + importance = function(type = "count") { assert_choice(type, c("prob", "count")) if (is.null(self$model)) { From 8af98e394fed137cb65026e9cbf5837bf4daf0d3 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 13:07:05 +0200 Subject: [PATCH 18/47] change K parameter type --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 1ab876ac1..b32561cb7 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -112,7 +112,7 @@ delayedAssign( #' Creates a new instance of this [R6][R6::R6Class] class. 
initialize = function() { param_set = ps( - K = p_uty(default = NULL, tags = c("train", "predict")), + K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")), events = p_uty(default = NULL, tags = c("train", "predict")), ztimes = p_uty(default = NULL, tags = c("train", "predict")), zdelta = p_uty(default = NULL, tags = c("train", "predict")), From f795951022330eae8da5fb4247ad25af9f26b5d8 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 15:10:29 +0200 Subject: [PATCH 19/47] simplify and speed-up the creation of the survival matrix --- R/learner_BART_surv_bart.R | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index b32561cb7..88922c63c 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -266,17 +266,8 @@ assert_choice(type, c("prob", "count")) # save the full posterior survival matrix self$state$surv_test = pred$surv.test - # create mean posterior survival matrix - surv = matrix(nrow = N, ncol = K) # obs x times - # check: output is flattened (column means) - stopifnot(length(pred$surv.test.mean) == N * K) - - for (i in 1:N) { - # every K-size vector contains the mean survival estimates of the - # i-th test observation - indxs = ((i-1)*K + 1):(i*K) - surv[i,] = pred$surv.test.mean[indxs] - } + # create mean posterior survival matrix (N obs x K times) + surv = matrix(bart_p2$surv.test.mean, nrow = N, ncol = K, byrow = TRUE) mlr3proba::.surv_return(times = pred$times, surv = surv) } From 4673be744cea2a6d2df7fcc5b6d8c04659b57de0 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 15:51:17 +0200 Subject: [PATCH 20/47] add importance parameter, remove factor feature type --- R/learner_BART_surv_bart.R | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 88922c63c..9910a43a3 100644 --- 
a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -21,6 +21,7 @@ #' #' @section Initial parameter values: #' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is `TRUE` by default. +#' - `importance` allows to choose the type of importance. Default is `count`, see `importance()` for more details. #' #' @templateVar id surv.bart #' @template learner @@ -141,16 +142,17 @@ delayedAssign( mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")), nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")), openmp = p_lgl(default = TRUE, tags = "predict"), - quiet = p_lgl(default = TRUE, tags = "predict") + quiet = p_lgl(default = TRUE, tags = "predict"), + importance = p_fct(default = "count", levels = c("count", "prob"), tags = "imp") ) # custom defaults - param_set$values = list(mc.cores = 1, quiet = TRUE) + param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count") super$initialize( id = "surv.bart", packages = "BART", - feature_types = c("logical", "integer", "numeric", "factor"), + feature_types = c("logical", "integer", "numeric"), predict_types = c("crank", "distr"), param_set = param_set, properties = c("importance", "missings"), @@ -160,27 +162,27 @@ delayedAssign( }, #' @description - #' Two types of importance scores are supported: - #' 1. The mean selection probability of each feature in the trees, + #' Two types of importance scores are supported based on the initial value + #' for the parameter `importance`: + #' 1. `prob`: The mean selection probability of each feature in the trees, #' extracted from the slot `varprob.mean`. #' If `sparse = FALSE` (default), this is a fixed constant. #' Recommended to use this option when `sparse = TRUE`. - #' 2. The observed count of each feature in the trees, extracted from the - #' slot `varcount.mean`. + #' 2. `count`: The observed count of each feature in the trees, extracted + #' from the slot `varcount.mean`. 
#' This is the default importance scores. #' #' In both cases, higher values signify more important variables. #' - #' @param type Can be either `count` or `prob`. #' @return Named `numeric()`. - importance = function(type = "count") { -assert_choice(type, c("prob", "count")) - + importance = function() { if (is.null(self$model)) { stopf("No model stored") } - if (type == 'prob') { + pars = self$param_set$get_values(tags = 'imp') + + if (pars$importance == 'prob') { sort(self$model$varprob.mean[-1], decreasing = T) } else { sort(self$model$varcount.mean[-1], decreasing = T) From b1af8e9989c6fa1277685a1a5be97502d671c1dd Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 16:10:07 +0200 Subject: [PATCH 21/47] change tag for importance to train + fix small bug --- R/learner_BART_surv_bart.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 9910a43a3..323ade75a 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -143,7 +143,7 @@ delayedAssign( nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")), openmp = p_lgl(default = TRUE, tags = "predict"), quiet = p_lgl(default = TRUE, tags = "predict"), - importance = p_fct(default = "count", levels = c("count", "prob"), tags = "imp") + importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train") ) # custom defaults @@ -180,7 +180,7 @@ delayedAssign( stopf("No model stored") } - pars = self$param_set$get_values(tags = 'imp') + pars = self$param_set$get_values(tags = 'train') if (pars$importance == 'prob') { sort(self$model$varprob.mean[-1], decreasing = T) @@ -193,6 +193,7 @@ delayedAssign( private = list( .train = function(task) { pars = self$param_set$get_values(tags = "train") + pars$importance = NULL # not used in the train function x_train = as.data.frame(task$data(cols = task$feature_names)) times = task$truth()[,1] @@ -269,7 +270,7 @@ delayedAssign( 
self$state$surv_test = pred$surv.test # create mean posterior survival matrix (N obs x K times) - surv = matrix(bart_p2$surv.test.mean, nrow = N, ncol = K, byrow = TRUE) + surv = matrix(pred$surv.test.mean, nrow = N, ncol = K, byrow = TRUE) mlr3proba::.surv_return(times = pred$times, surv = surv) } From ef4256419a195bca86fbadccb8ff15e2fb0065ff Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 16:35:46 +0200 Subject: [PATCH 22/47] update doc --- man/mlr_learners_surv.bart.Rd | 38 +++++++++++------------------------ 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index f5b7c7552..bd043fc2a 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -17,28 +17,19 @@ draws \code{M}. See example for more details. Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. -Based on BART version 2.9.4 (2023-03-25). } \section{Custom mlr3 defaults}{ \itemize{ -\item \code{mc.cores} (in general use as many as possible if no issues arise): -\itemize{ -\item Actual default: 2 -\item Adjusted value: 1 -\item Reason for change: May conflict with parallelization via \CRANpkg{future}. -} +\item \code{mc.cores} is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. } } \section{Initial parameter values}{ \itemize{ -\item \code{quiet} -\itemize{ -\item Default is \code{TRUE}, to suppress messages generated by the wrapped C++ code -during prediction. -} +\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is \code{TRUE} by default. +\item \code{importance} allows to choose the type of importance. Default is \code{count}, see \code{importance()} for more details. 
} } @@ -56,7 +47,7 @@ lrn("surv.bart") \itemize{ \item Task type: \dQuote{surv} \item Predict Types: \dQuote{crank}, \dQuote{distr} -\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{factor} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric} \item Required Packages: \CRANpkg{mlr3}, \CRANpkg{mlr3proba}, \CRANpkg{BART} } } @@ -64,7 +55,7 @@ lrn("surv.bart") \section{Parameters}{ \tabular{lllll}{ Id \tab Type \tab Default \tab Levels \tab Range \cr - K \tab untyped \tab \tab \tab - \cr + K \tab numeric \tab NULL \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr events \tab untyped \tab \tab \tab - \cr ztimes \tab untyped \tab \tab \tab - \cr zdelta \tab untyped \tab \tab \tab - \cr @@ -94,6 +85,7 @@ lrn("surv.bart") nice \tab integer \tab 19 \tab \tab \eqn{[0, 19]}{[0, 19]} \cr openmp \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr quiet \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr + importance \tab character \tab count \tab count, prob \tab - \cr } } @@ -234,29 +226,23 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-LearnerSurvLearnerSurvBART-importance}{}}} \subsection{Method \code{importance()}}{ -Two types of importance scores are supported: +Two types of importance scores are supported based on the initial value +for the parameter \code{importance}: \enumerate{ -\item The mean selection probability of each feature in the trees, +\item \code{prob}: The mean selection probability of each feature in the trees, extracted from the slot \code{varprob.mean}. If \code{sparse = FALSE} (default), this is a fixed constant. Recommended to use this option when \code{sparse = TRUE}. -\item The observed count of each feature in the trees, extracted from the -slot \code{varcount.mean}. +\item \code{count}: The observed count of each feature in the trees, extracted +from the slot \code{varcount.mean}. This is the default importance scores. 
} In both cases, higher values signify more important variables. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{LearnerSurvLearnerSurvBART$importance(type = "count")}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{LearnerSurvLearnerSurvBART$importance()}\if{html}{\out{
}} } -\subsection{Arguments}{ -\if{html}{\out{
}} -\describe{ -\item{\code{type}}{Can be either \code{count} or \code{prob}.} -} -\if{html}{\out{
}} -} \subsection{Returns}{ Named \code{numeric()}. } From 031065c99d24d3c9db3c59a6d962799bae482f07 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 16:36:09 +0200 Subject: [PATCH 23/47] fix tests --- tests/testthat/test_BART_surv_bart.R | 4 +--- tests/testthat/test_paramtest_BART_surv_bart.R | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test_BART_surv_bart.R b/tests/testthat/test_BART_surv_bart.R index cc1ecec49..60c25a568 100644 --- a/tests/testthat/test_BART_surv_bart.R +++ b/tests/testthat/test_BART_surv_bart.R @@ -1,8 +1,6 @@ test_that("autotest", { learner = lrn("surv.bart", nskip = 1, ndpost = 1, keepevery = 1) expect_learner(learner) - # `feat_all` task has a factor which is split to two variables which give two - # different importance scores and not just once, so it fails (but it's okay) - result = run_autotest(learner, exclude = "feat_all") + result = run_autotest(learner) expect_true(result, info = result$error) }) diff --git a/tests/testthat/test_paramtest_BART_surv_bart.R b/tests/testthat/test_paramtest_BART_surv_bart.R index 0e94a6d6e..d0dff88d5 100644 --- a/tests/testthat/test_paramtest_BART_surv_bart.R +++ b/tests/testthat/test_paramtest_BART_surv_bart.R @@ -9,7 +9,8 @@ test_that("paramtest surv.bart train", { "x.test", # not used "xinfo", # not used "tau.num", # not used, automatically calculated from `type` and `ntype` - "id" # only for `surv.bart` + "id", # only for `surv.bart` + "importance" # added to choose the type of importance ) paramtest = run_paramtest(learner, fun_list, exclude, tag = "train") From 16330d82346362207854098601cbc11ee1d54887 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 16:50:51 +0200 Subject: [PATCH 24/47] change section name --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 323ade75a..5d3b11aad 100644 --- a/R/learner_BART_surv_bart.R +++ 
b/R/learner_BART_surv_bart.R @@ -19,7 +19,7 @@ #' @section Custom mlr3 defaults: #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. #' -#' @section Initial parameter values: +#' @section Custom mlr3 parameters: #' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is `TRUE` by default. #' - `importance` allows to choose the type of importance. Default is `count`, see `importance()` for more details. #' From 94b174248c426a60f649c1aa8872c08efa137b1f Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 16:58:28 +0200 Subject: [PATCH 25/47] remove BART example and extra libraries --- DESCRIPTION | 3 -- R/learner_BART_surv_bart.R | 77 +------------------------------------- 2 files changed, 2 insertions(+), 78 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index fd79b4700..650b85da0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -43,7 +43,6 @@ Suggests: Cubist, curl, dbarts, - dplyr, distr6, earth, flexsurv, @@ -51,7 +50,6 @@ Suggests: formattable, future, gbm, - ggplot2, glmnet, gss, jsonlite, @@ -101,7 +99,6 @@ Suggests: survivalsvm, tensorflow (>= 2.0.0), testthat, - tidyr, xgboost Remotes: alan-turing-institute/distr6, diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 5d3b11aad..18870c34e 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -5,6 +5,7 @@ #' @description #' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored #' survival data. +#' #' For prediction, we return the mean posterior estimates of the survival #' function and the corresponding `crank` (expected mortality) using #' [mlr3proba::.surv_return]. @@ -12,7 +13,6 @@ #' `learner$state$surv_test` slot, along with the number of test observations #' `N`, number of unique times in the train set `K` and number of posterior #' draws `M`. -#' See example for more details. #' #' Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. 
#' @@ -30,80 +30,7 @@ #' `r format_bib("sparapani2021nonparametric", "chipman2010bart")` #' #' @template seealso_learner -#' @examples -#' library(mlr3proba) -#' library(dplyr) -#' library(tidyr) -#' library(ggplot2) -#' -#' learner = lrn("surv.bart", nskip = 10, ndpost = 20, keepevery = 2) -#' task = tsk("lung") -#' task$missings() # has missing values -#' -#' # split to train and test sets -#' set.seed(42) -#' part = partition(task) -#' -#' # Train -#' learner$train(task, row_ids = part$train) -#' -#' # Importance: average number of times a feature has been used in the trees -#' learner$importance() -#' -#' # Test -#' p = learner$predict(task, row_ids = part$test) -#' p$score() # C-index -#' -#' # Mean survival probabilities for the first 3 patients at given time points -#' p$distr$survival(times = c(1,50,150))[,1:3] -#' -#' # number of posterior draws -#' M = learner$state$M -#' stopifnot(M == 20) -#' # number of test observations -#' N = learner$state$N -#' stopifnot(N == length(part$test)) -#' # number of unique time points in the train set -#' K = learner$state$K -#' stopifnot(K == length(task$unique_times(rows = part$train))) -#' # the actual times are also available in the `$model` slot: -#' head(learner$model$times) -#' -#' # Full posterior prediction matrix -#' surv_test = learner$state$surv_test -#' stopifnot(all(dim(surv_test) == c(M, K * N))) -#' -#' # Posterior survival function estimates for the 1st test patient for all -#' # time points (from the train set) - see Sparapani (2021), pages 34-35 -#' post_surv = surv_test[, 1:K] -#' -#' # For every time point, get the median survival estimate as well as -#' # the lower and upper bounds of the 95% quantile credible interval -#' surv_data = post_surv %>% -#' as.data.frame() %>% -#' `colnames<-` (learner$model$times) %>% -#' summarise(across(everything(), list( -#' median = ~ median(.), -#' low_qi = ~ quantile(., 0.025), -#' high_qi = ~ quantile(., 0.975) -#' ))) %>% -#' pivot_longer( -#' cols = 
everything(), -#' names_to = c("times", ".value"), -#' names_pattern = "(^[^_]+)_(.*)" # everything until the first underscore -#' ) %>% -#' mutate(times = as.numeric(times)) -#' surv_data -#' -#' # Draw a survival curve for the first patient in the test set with -#' # uncertainty quantified -#' surv_data %>% -#' ggplot(aes(x = times, y = median)) + -#' geom_step(col = 'black') + -#' xlab('Time (Days)') + -#' ylab('Survival Probability') + -#' geom_ribbon(aes(ymin = low_qi, ymax = high_qi), alpha = 0.3) + -#' theme_bw() +#' @template example #' @export delayedAssign( "LearnerSurvLearnerSurvBART", R6Class("LearnerSurvLearnerSurvBART", From e3a3700296a962ed9afd2fb7f85c9cc8397a8d38 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 21 Aug 2023 22:51:26 +0200 Subject: [PATCH 26/47] return model list slot and name refactoring --- R/learner_BART_surv_bart.R | 53 ++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 18870c34e..95cc2aad8 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -10,9 +10,7 @@ #' function and the corresponding `crank` (expected mortality) using #' [mlr3proba::.surv_return]. #' The full posterior estimates are currently stored in the -#' `learner$state$surv_test` slot, along with the number of test observations -#' `N`, number of unique times in the train set `K` and number of posterior -#' draws `M`. +#' `learner$model$surv.test` slot. #' #' Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. #' @@ -103,16 +101,16 @@ delayedAssign( #' #' @return Named `numeric()`. 
importance = function() { - if (is.null(self$model)) { + if (is.null(self$model$model)) { stopf("No model stored") } pars = self$param_set$get_values(tags = 'train') if (pars$importance == 'prob') { - sort(self$model$varprob.mean[-1], decreasing = T) + sort(self$model$model$varprob.mean[-1], decreasing = T) } else { - sort(self$model$varcount.mean[-1], decreasing = T) + sort(self$model$model$varcount.mean[-1], decreasing = T) } } ), @@ -122,19 +120,22 @@ delayedAssign( pars = self$param_set$get_values(tags = "train") pars$importance = NULL # not used in the train function - x_train = as.data.frame(task$data(cols = task$feature_names)) + x.train = as.data.frame(task$data(cols = task$feature_names)) times = task$truth()[,1] delta = task$truth()[,2] # delta => status - # need these for predict - self$state$train_data = list(x_train = x_train, times = times, delta = delta) - - invoke( - BART::mc.surv.bart, - x.train = x_train, + list( + model = invoke( + BART::mc.surv.bart, + x.train = x.train, + times = times, + delta = delta, + .args = pars + ), + # need these for predict + x.train = x.train, times = times, - delta = delta, - .args = pars + delta = delta ) }, @@ -143,20 +144,18 @@ delayedAssign( pars = self$param_set$get_values(tags = "predict") # get newdata and ensure same ordering in train and predict - x_test = as.data.frame(ordered_features(task, self)) - - # transform data to be suitable for BART survival analysis (needs train data) - train_data = self$state$train_data + x.test = as.data.frame(ordered_features(task, self)) # subset parameters to use in `surv.pre.bart` pars_pre = pars[names(pars) %in% c('K', 'events', 'ztimes', 'zdelta')] + # transform data to be suitable for BART survival analysis (needs train data) trans_data = invoke( BART::surv.pre.bart, - times = train_data$times, - delta = train_data$delta, - x.train = train_data$x_train, - x.test = x_test, + times = self$model$times, + delta = self$model$delta, + x.train = self$model$x.train, + x.test = 
x.test, .args = pars_pre ) @@ -166,7 +165,7 @@ delayedAssign( pred_fun = function() { invoke( predict, - self$model, + self$model$model, newdata = trans_data$tx.test, .args = pars_pred ) @@ -186,15 +185,13 @@ delayedAssign( # Number of test observations N = task$nrow - self$state$N = N # Number of unique times K = pred$K - self$state$K = K # Number of posterior draws - self$state$M = nrow(pred$surv.test) + M = nrow(pred$surv.test) # save the full posterior survival matrix - self$state$surv_test = pred$surv.test + self$model$surv.test = pred$surv.test # create mean posterior survival matrix (N obs x K times) surv = matrix(pred$surv.test.mean, nrow = N, ncol = K, byrow = TRUE) From d52c6fa66f10e0553682ef47b618758b44b3a4c5 Mon Sep 17 00:00:00 2001 From: john Date: Tue, 22 Aug 2023 09:46:58 +0200 Subject: [PATCH 27/47] update doc --- man/mlr_learners_surv.bart.Rd | 84 +++-------------------------------- 1 file changed, 7 insertions(+), 77 deletions(-) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index bd043fc2a..1fb67cc0f 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -7,14 +7,12 @@ \description{ Fits a Bayesian Additive Regression Trees (BART) learner to right-censored survival data. + For prediction, we return the mean posterior estimates of the survival function and the corresponding \code{crank} (expected mortality) using \link[mlr3proba:dot-surv_return]{mlr3proba::.surv_return}. The full posterior estimates are currently stored in the -\code{learner$state$surv_test} slot, along with the number of test observations -\code{N}, number of unique times in the train set \code{K} and number of posterior -draws \code{M}. -See example for more details. +\code{learner$model$surv.test} slot. Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. } @@ -25,7 +23,7 @@ Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. 
} } -\section{Initial parameter values}{ +\section{Custom mlr3 parameters}{ \itemize{ \item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is \code{TRUE} by default. @@ -90,79 +88,11 @@ lrn("surv.bart") } \examples{ -library(mlr3proba) -library(dplyr) -library(tidyr) -library(ggplot2) - -learner = lrn("surv.bart", nskip = 10, ndpost = 20, keepevery = 2) -task = tsk("lung") -task$missings() # has missing values - -# split to train and test sets -set.seed(42) -part = partition(task) - -# Train -learner$train(task, row_ids = part$train) - -# Importance: average number of times a feature has been used in the trees -learner$importance() - -# Test -p = learner$predict(task, row_ids = part$test) -p$score() # C-index - -# Mean survival probabilities for the first 3 patients at given time points -p$distr$survival(times = c(1,50,150))[,1:3] - -# number of posterior draws -M = learner$state$M -stopifnot(M == 20) -# number of test observations -N = learner$state$N -stopifnot(N == length(part$test)) -# number of unique time points in the train set -K = learner$state$K -stopifnot(K == length(task$unique_times(rows = part$train))) -# the actual times are also available in the `$model` slot: -head(learner$model$times) - -# Full posterior prediction matrix -surv_test = learner$state$surv_test -stopifnot(all(dim(surv_test) == c(M, K * N))) - -# Posterior survival function estimates for the 1st test patient for all -# time points (from the train set) - see Sparapani (2021), pages 34-35 -post_surv = surv_test[, 1:K] - -# For every time point, get the median survival estimate as well as -# the lower and upper bounds of the 95\% quantile credible interval -surv_data = post_surv \%>\% - as.data.frame() \%>\% - `colnames<-` (learner$model$times) \%>\% - summarise(across(everything(), list( - median = ~ median(.), - low_qi = ~ quantile(., 0.025), - high_qi = ~ quantile(., 0.975) - ))) \%>\% - pivot_longer( - cols = everything(), - names_to = c("times", 
".value"), - names_pattern = "(^[^_]+)_(.*)" # everything until the first underscore - ) \%>\% - mutate(times = as.numeric(times)) -surv_data +learner = mlr3::lrn("surv.bart") +print(learner) -# Draw a survival curve for the first patient in the test set with -# uncertainty quantified -surv_data \%>\% - ggplot(aes(x = times, y = median)) + - geom_step(col = 'black') + - xlab('Time (Days)') + - ylab('Survival Probability') + - geom_ribbon(aes(ymin = low_qi, ymax = high_qi), alpha = 0.3) + - theme_bw() +# available parameters: +learner$param_set$ids() } \references{ Sparapani, Rodney, Spanbauer, Charles, McCulloch, Robert (2021). From c8041b458a21be490b25352147d36b1d8fefd210 Mon Sep 17 00:00:00 2001 From: john Date: Fri, 8 Sep 2023 15:05:33 +0200 Subject: [PATCH 28/47] store full posterior survival array (testing version) * will work with latest version of distr6 --- R/learner_BART_surv_bart.R | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 95cc2aad8..394772930 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -180,23 +180,37 @@ delayedAssign( pred = pred_fun() } - # Build survival matrix using the mean posterior estimates of the survival - # function, see page 34-35 in Sparapani (2021) for more details - # Number of test observations N = task$nrow # Number of unique times K = pred$K + times = pred$times # Number of posterior draws M = nrow(pred$surv.test) - # save the full posterior survival matrix + # save the full posterior survival matrix and the mean for checking + # TODO: remove next two lines self$model$surv.test = pred$surv.test + self$model$surv.test.mean = pred$surv.test.mean + + # Convert full posterior survival matrix to 3D survival array + # See page 34-35 in Sparapani (2021) for more details + surv.array = aperm( + array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)), + c(3, 2, 1) + ) - # create 
mean posterior survival matrix (N obs x K times) + # Mean posterior survival matrix (N obs x K times) surv = matrix(pred$surv.test.mean, nrow = N, ncol = K, byrow = TRUE) - mlr3proba::.surv_return(times = pred$times, surv = surv) + # get crank as expected mortality using mean posterior + pred_list = mlr3proba::.surv_return(times = times, surv = surv) + + # replace with the full survival posterior + pred_list$distr = surv.array + + # return list with crank and distr + pred_list } ) ) From f162a31cd8b5e269b12a83687168648f12bf157a Mon Sep 17 00:00:00 2001 From: john Date: Mon, 11 Sep 2023 15:07:12 +0200 Subject: [PATCH 29/47] update mlr3proba to 0.5.3 + refactoring --- DESCRIPTION | 2 +- R/learner_BART_surv_bart.R | 14 +++----------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 650b85da0..a5677f7bd 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -67,7 +67,7 @@ Suggests: mgcv, mlr3cluster, mlr3learners (>= 0.4.2), - mlr3proba, + mlr3proba (>= 0.5.3), mlr3pipelines, mvtnorm, nnet, diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 394772930..e0c0a926a 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -200,17 +200,9 @@ delayedAssign( c(3, 2, 1) ) - # Mean posterior survival matrix (N obs x K times) - surv = matrix(pred$surv.test.mean, nrow = N, ncol = K, byrow = TRUE) - - # get crank as expected mortality using mean posterior - pred_list = mlr3proba::.surv_return(times = times, surv = surv) - - # replace with the full survival posterior - pred_list$distr = surv.array - - # return list with crank and distr - pred_list + # distr => full survival array + # crank => expected mortality using the mean posterior survival matrix + mlr3proba::.surv_return(times = times, surv = surv.array, which.curve = 'mean') } ) ) From 460b0e8ddee3168fa46b0a193a249ab32b9b4f96 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 18 Sep 2023 12:59:00 +0200 Subject: [PATCH 30/47] add which.curve parameter, 
defaults to 0.5 (median posterior) --- R/learner_BART_surv_bart.R | 72 ++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index e0c0a926a..31d8b9b6b 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -20,6 +20,7 @@ #' @section Custom mlr3 parameters: #' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is `TRUE` by default. #' - `importance` allows to choose the type of importance. Default is `count`, see `importance()` for more details. +#' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. By default the median posterior is used. #' #' @templateVar id surv.bart #' @template learner @@ -38,41 +39,43 @@ delayedAssign( #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function() { param_set = ps( - K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")), - events = p_uty(default = NULL, tags = c("train", "predict")), - ztimes = p_uty(default = NULL, tags = c("train", "predict")), - zdelta = p_uty(default = NULL, tags = c("train", "predict")), - sparse = p_lgl(default = FALSE, tags = "train"), - theta = p_dbl(default = 0, tags = "train"), - omega = p_dbl(default = 1, tags = "train"), - a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"), - b = p_dbl(default = 1L, tags = "train"), - augment = p_lgl(default = FALSE, tags = "train"), - rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), - usequants = p_lgl(default = FALSE, tags = "train"), - rm.const = p_lgl(default = TRUE, tags = "train"), - type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"), - ntype = p_int(lower = 1, upper = 3, tags = "train"), - k = p_dbl(default = 2.0, lower = 0, tags = "train"), - power = p_dbl(default = 2.0, lower = 0, tags = "train"), - base = p_dbl(default = 
0.95, lower = 0, upper = 1, tags = "train"), - offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), - ntree = p_int(default = 50L, lower = 1L, tags = "train"), - numcut = p_int(default = 100L, lower = 1L, tags = "train"), - ndpost = p_int(default = 1000L, lower = 1L, tags = "train"), - nskip = p_int(default = 250L, lower = 0L, tags = "train"), - keepevery = p_int(default = 10L, lower = 1L, tags = "train"), + K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")), + events = p_uty(default = NULL, tags = c("train", "predict")), + ztimes = p_uty(default = NULL, tags = c("train", "predict")), + zdelta = p_uty(default = NULL, tags = c("train", "predict")), + sparse = p_lgl(default = FALSE, tags = "train"), + theta = p_dbl(default = 0, tags = "train"), + omega = p_dbl(default = 1, tags = "train"), + a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"), + b = p_dbl(default = 1L, tags = "train"), + augment = p_lgl(default = FALSE, tags = "train"), + rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), + usequants = p_lgl(default = FALSE, tags = "train"), + rm.const = p_lgl(default = TRUE, tags = "train"), + type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"), + ntype = p_int(lower = 1, upper = 3, tags = "train"), + k = p_dbl(default = 2.0, lower = 0, tags = "train"), + power = p_dbl(default = 2.0, lower = 0, tags = "train"), + base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"), + offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), + ntree = p_int(default = 50L, lower = 1L, tags = "train"), + numcut = p_int(default = 100L, lower = 1L, tags = "train"), + ndpost = p_int(default = 1000L, lower = 1L, tags = "train"), + nskip = p_int(default = 250L, lower = 0L, tags = "train"), + keepevery = p_int(default = 10L, lower = 1L, tags = "train"), printevery = p_int(default = 100L, lower = 1L, tags = "train"), - seed = 
p_int(default = 99L, tags = "train"), - mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")), - nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")), - openmp = p_lgl(default = TRUE, tags = "predict"), - quiet = p_lgl(default = TRUE, tags = "predict"), - importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train") + seed = p_int(default = 99L, tags = "train"), + mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")), + nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")), + openmp = p_lgl(default = TRUE, tags = "predict"), + quiet = p_lgl(default = TRUE, tags = "predict"), + importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"), + which.curve = p_uty(tags = "predict") ) # custom defaults - param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count") + param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count", + which.curve = 0.5) # 0.5 quantile => median posterior super$initialize( id = "surv.bart", @@ -200,9 +203,10 @@ delayedAssign( c(3, 2, 1) ) - # distr => full survival array - # crank => expected mortality using the mean posterior survival matrix - mlr3proba::.surv_return(times = times, surv = surv.array, which.curve = 'mean') + # distr => 3d survival array + # crank => expected mortality + mlr3proba::.surv_return(times = times, surv = surv.array, + which.curve = pars_pred$which.curve) } ) ) From bdc09afebbe5e22fabeab232dbb530d3ae09459f Mon Sep 17 00:00:00 2001 From: john Date: Wed, 20 Sep 2023 00:40:02 +0300 Subject: [PATCH 31/47] update doc --- R/learner_BART_surv_bart.R | 20 +++++++++++--------- man/mlr_learners_surv.bart.Rd | 24 +++++++++++++++--------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 31d8b9b6b..68a8060ea 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -4,15 
+4,17 @@ #' #' @description #' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored -#' survival data. +#' survival data. Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. #' -#' For prediction, we return the mean posterior estimates of the survival -#' function and the corresponding `crank` (expected mortality) using -#' [mlr3proba::.surv_return]. -#' The full posterior estimates are currently stored in the -#' `learner$model$surv.test` slot. -#' -#' Calls [BART::mc.surv.bart()] from \CRANpkg{BART}. +#' @details +#' Two types of prediction are returned for this learner: +#' 1. `distr`: a 3d survival array with observations as 1st dimension, time +#' points as 2nd and the posterior draws as 3rd dimension. +#' 2. `crank`: the expected mortality using [mlr3proba::.surv_return]. The parameter +#' `which.curve` decides which posterior draw (3rd dimension) will be used for the +#' calculation of the expected mortality. Note that the median posterior is +#' by default used for the calculation of survival measures that require a `distr` +#' prediction, see more info on [PredictionSurv][mlr3proba::PredictionSurv]. #' #' @section Custom mlr3 defaults: #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. @@ -20,7 +22,7 @@ #' @section Custom mlr3 parameters: #' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is `TRUE` by default. #' - `importance` allows to choose the type of importance. Default is `count`, see `importance()` for more details. -#' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. By default the median posterior is used. +#' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. By default the **median posterior** is used. 
#' #' @templateVar id surv.bart #' @template learner diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index 1fb67cc0f..0ce8e2cd6 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -6,15 +6,19 @@ \title{Survival Bayesian Additive Regression Trees Learner} \description{ Fits a Bayesian Additive Regression Trees (BART) learner to right-censored -survival data. - -For prediction, we return the mean posterior estimates of the survival -function and the corresponding \code{crank} (expected mortality) using -\link[mlr3proba:dot-surv_return]{mlr3proba::.surv_return}. -The full posterior estimates are currently stored in the -\code{learner$model$surv.test} slot. - -Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. +survival data. Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. +} +\details{ +Two types of prediction are returned for this learner: +\enumerate{ +\item \code{distr}: a 3d survival array with observations as 1st dimension, time +points as 2nd and the posterior draws as 3rd dimension. +\item \code{crank}: the expected mortality using \link[mlr3proba:dot-surv_return]{mlr3proba::.surv_return}. The parameter +\code{which.curve} decides which posterior draw (3rd dimension) will be used for the +calculation of the expected mortality. Note that the median posterior is +by default used for the calculation of survival measures that require a \code{distr} +prediction, see more info on \link[mlr3proba:PredictionSurv]{PredictionSurv}. +} } \section{Custom mlr3 defaults}{ @@ -28,6 +32,7 @@ Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. \itemize{ \item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is \code{TRUE} by default. \item \code{importance} allows to choose the type of importance. Default is \code{count}, see \code{importance()} for more details. 
+\item \code{which.curve} allows to choose which posterior draw will be used for the calculation of the \code{crank} prediction. By default the \strong{median posterior} is used. } } @@ -84,6 +89,7 @@ lrn("surv.bart") openmp \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr quiet \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr importance \tab character \tab count \tab count, prob \tab - \cr + which.curve \tab untyped \tab - \tab \tab - \cr } } From 25b9d8dac83f847ff35a6ebf20e2cd4759ea7f44 Mon Sep 17 00:00:00 2001 From: john Date: Wed, 20 Sep 2023 00:42:34 +0300 Subject: [PATCH 32/47] update BART test --- tests/testthat/test_paramtest_BART_surv_bart.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test_paramtest_BART_surv_bart.R b/tests/testthat/test_paramtest_BART_surv_bart.R index d0dff88d5..d55c865eb 100644 --- a/tests/testthat/test_paramtest_BART_surv_bart.R +++ b/tests/testthat/test_paramtest_BART_surv_bart.R @@ -29,7 +29,8 @@ test_that("paramset surv.bart predict", { "x.train", # handled internally, used in `surv.pre.bart` "x.test", # handled internally, used in `surv.pre.bart` "nice", # handled internally - "quiet" # added to suppress print messages + "quiet", # added to suppress print messages + "which.curve" # added to choose 3rd dimension (posterior draw) for crank calculation ) paramtest = run_paramtest(learner, fun_list, exclude, tag = "predict") From 6e9bcff2e5d79287e21d0c9032307a40219f3781 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 5 Oct 2023 10:58:05 +0200 Subject: [PATCH 33/47] remove code after checks (distr6 converts survival array correctly) --- R/learner_BART_surv_bart.R | 5 ----- 1 file changed, 5 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 68a8060ea..398e9a81f 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -193,11 +193,6 @@ delayedAssign( # Number of posterior draws M = nrow(pred$surv.test) - # save the full posterior survival 
matrix and the mean for checking - # TODO: remove next two lines - self$model$surv.test = pred$surv.test - self$model$surv.test.mean = pred$surv.test.mean - # Convert full posterior survival matrix to 3D survival array # See page 34-35 in Sparapani (2021) for more details surv.array = aperm( From 21e5be9e87055df5418f0669113046e2d46f95c7 Mon Sep 17 00:00:00 2001 From: john Date: Fri, 6 Oct 2023 14:16:17 +0200 Subject: [PATCH 34/47] better construction of 'which.curve' parameter --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 398e9a81f..bb0bf1f55 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -72,7 +72,7 @@ delayedAssign( openmp = p_lgl(default = TRUE, tags = "predict"), quiet = p_lgl(default = TRUE, tags = "predict"), importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"), - which.curve = p_uty(tags = "predict") + which.curve = p_dbl(special_vals = list("mean"), tags = "predict") ) # custom defaults From b7064ffdf4946a45acb7cb559387677a9d6cb46d Mon Sep 17 00:00:00 2001 From: john Date: Fri, 6 Oct 2023 15:22:22 +0200 Subject: [PATCH 35/47] fix bug (which.curve was always NULL) --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index bb0bf1f55..5ac5eb852 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -203,7 +203,7 @@ delayedAssign( # distr => 3d survival array # crank => expected mortality mlr3proba::.surv_return(times = times, surv = surv.array, - which.curve = pars_pred$which.curve) + which.curve = pars$which.curve) } ) ) From ffbcdc0cd659882f45d145a25995e158f65cccc5 Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 16 Oct 2023 17:05:20 +0200 Subject: [PATCH 36/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 5ac5eb852..8dff14d0e 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -20,7 +20,7 @@ #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. #' #' @section Custom mlr3 parameters: -#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is `TRUE` by default. +#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is initialized to `TRUE`. #' - `importance` allows to choose the type of importance. Default is `count`, see `importance()` for more details. #' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. By default the **median posterior** is used. #' From 362c8459f92b0996622d57d9d5c9d3b340a63bd0 Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 16 Oct 2023 17:10:10 +0200 Subject: [PATCH 37/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 8dff14d0e..b8ee0cc43 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -92,7 +92,7 @@ delayedAssign( }, #' @description - #' Two types of importance scores are supported based on the initial value + #' Two types of importance scores are supported based on the value #' for the parameter `importance`: #' 1. `prob`: The mean selection probability of each feature in the trees, #' extracted from the slot `varprob.mean`. 
From 1231d25b0dcad271b0b001001829304f640a1e2e Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Mon, 16 Oct 2023 17:10:46 +0200 Subject: [PATCH 38/47] Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer --- R/learner_BART_surv_bart.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index b8ee0cc43..d4dad378b 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -93,7 +93,7 @@ delayedAssign( #' @description #' Two types of importance scores are supported based on the value - #' for the parameter `importance`: + #' of the parameter `importance`: #' 1. `prob`: The mean selection probability of each feature in the trees, #' extracted from the slot `varprob.mean`. #' If `sparse = FALSE` (default), this is a fixed constant. From 7e30f2a6ec0c315b4540fb4d827c4c78260bda54 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 16 Oct 2023 17:39:25 +0200 Subject: [PATCH 39/47] remove new parameter (to be corrected in another PR) --- tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R b/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R index 8daaec6aa..005560f9a 100644 --- a/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R +++ b/tests/testthat/test_paramtest_survivalmodels_surv_dnnsurv.R @@ -10,7 +10,6 @@ test_that("paramtest surv.dnnsurv train", { "status_variable", # handled internally "x", # unused "y", # unused - "schedule_decay", # not used "rho", # handled internally "global_clipnorm", # handled internally "use_ema", # handled internally From eeff74fbe511de5adb5fed8a42be0d9865b4b162 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 16 Oct 2023 17:47:38 +0200 Subject: [PATCH 40/47] changes after code review --- R/learner_BART_surv_bart.R | 21 +++++++++++---------- man/mlr_learners_surv.bart.Rd | 6 +++--- 2 files changed, 14 insertions(+), 
13 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index d4dad378b..592dfc755 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -21,7 +21,7 @@ #' #' @section Custom mlr3 parameters: #' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is initialized to `TRUE`. -#' - `importance` allows to choose the type of importance. Default is `count`, see `importance()` for more details. +#' - `importance` allows to choose the type of importance. Default is `count`, see documentation of method `$importance()` for more details. #' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. By default the **median posterior** is used. #' #' @templateVar id surv.bart @@ -72,7 +72,7 @@ delayedAssign( openmp = p_lgl(default = TRUE, tags = "predict"), quiet = p_lgl(default = TRUE, tags = "predict"), importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"), - which.curve = p_dbl(special_vals = list("mean"), tags = "predict") + which.curve = p_dbl(lower = 0L, special_vals = list("mean"), tags = "predict") ) # custom defaults @@ -110,12 +110,12 @@ delayedAssign( stopf("No model stored") } - pars = self$param_set$get_values(tags = 'train') + pars = self$param_set$get_values(tags = "train") - if (pars$importance == 'prob') { - sort(self$model$model$varprob.mean[-1], decreasing = T) + if (pars$importance == "prob") { + sort(self$model$model$varprob.mean[-1], decreasing = TRUE) } else { - sort(self$model$model$varcount.mean[-1], decreasing = T) + sort(self$model$model$varcount.mean[-1], decreasing = TRUE) } } ), @@ -126,8 +126,9 @@ delayedAssign( pars$importance = NULL # not used in the train function x.train = as.data.frame(task$data(cols = task$feature_names)) - times = task$truth()[,1] - delta = task$truth()[,2] # delta => status + truth = task$truth() + times = truth[,1] + delta = truth[,2] # delta => status list( model = 
invoke( @@ -152,7 +153,7 @@ delayedAssign( x.test = as.data.frame(ordered_features(task, self)) # subset parameters to use in `surv.pre.bart` - pars_pre = pars[names(pars) %in% c('K', 'events', 'ztimes', 'zdelta')] + pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")] # transform data to be suitable for BART survival analysis (needs train data) trans_data = invoke( @@ -165,7 +166,7 @@ delayedAssign( ) # subset parameters to use in `predict` - pars_pred = pars[names(pars) %in% c('mc.cores', 'nice')] + pars_pred = pars[names(pars) %in% c("mc.cores", "nice")] pred_fun = function() { invoke( diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index 0ce8e2cd6..a8a1c42cc 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -30,8 +30,8 @@ prediction, see more info on \link[mlr3proba:PredictionSurv]{PredictionSurv}. \section{Custom mlr3 parameters}{ \itemize{ -\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is \code{TRUE} by default. -\item \code{importance} allows to choose the type of importance. Default is \code{count}, see \code{importance()} for more details. +\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is initialized to \code{TRUE}. +\item \code{importance} allows to choose the type of importance. Default is \code{count}, see documentation of method \verb{$importance()} for more details. \item \code{which.curve} allows to choose which posterior draw will be used for the calculation of the \code{crank} prediction. By default the \strong{median posterior} is used. 
} } @@ -89,7 +89,7 @@ lrn("surv.bart") openmp \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr quiet \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr importance \tab character \tab count \tab count, prob \tab - \cr - which.curve \tab untyped \tab - \tab \tab - \cr + which.curve \tab numeric \tab - \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr } } From 2c0d0a3a04bdc3123fc562f982e17208e44f3701 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 16 Oct 2023 17:51:21 +0200 Subject: [PATCH 41/47] remove delayedAssign + add more review suggestions --- R/learner_BART_surv_bart.R | 331 ++++++++++++++++++------------------- 1 file changed, 165 insertions(+), 166 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 592dfc755..2fb09247a 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -33,180 +33,179 @@ #' @template seealso_learner #' @template example #' @export -delayedAssign( - "LearnerSurvLearnerSurvBART", R6Class("LearnerSurvLearnerSurvBART", - inherit = mlr3proba::LearnerSurv, - public = list( - #' @description - #' Creates a new instance of this [R6][R6::R6Class] class. 
- initialize = function() { - param_set = ps( - K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")), - events = p_uty(default = NULL, tags = c("train", "predict")), - ztimes = p_uty(default = NULL, tags = c("train", "predict")), - zdelta = p_uty(default = NULL, tags = c("train", "predict")), - sparse = p_lgl(default = FALSE, tags = "train"), - theta = p_dbl(default = 0, tags = "train"), - omega = p_dbl(default = 1, tags = "train"), - a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"), - b = p_dbl(default = 1L, tags = "train"), - augment = p_lgl(default = FALSE, tags = "train"), - rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), - usequants = p_lgl(default = FALSE, tags = "train"), - rm.const = p_lgl(default = TRUE, tags = "train"), - type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"), - ntype = p_int(lower = 1, upper = 3, tags = "train"), - k = p_dbl(default = 2.0, lower = 0, tags = "train"), - power = p_dbl(default = 2.0, lower = 0, tags = "train"), - base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"), - offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), - ntree = p_int(default = 50L, lower = 1L, tags = "train"), - numcut = p_int(default = 100L, lower = 1L, tags = "train"), - ndpost = p_int(default = 1000L, lower = 1L, tags = "train"), - nskip = p_int(default = 250L, lower = 0L, tags = "train"), - keepevery = p_int(default = 10L, lower = 1L, tags = "train"), - printevery = p_int(default = 100L, lower = 1L, tags = "train"), - seed = p_int(default = 99L, tags = "train"), - mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")), - nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")), - openmp = p_lgl(default = TRUE, tags = "predict"), - quiet = p_lgl(default = TRUE, tags = "predict"), - importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"), - 
which.curve = p_dbl(lower = 0L, special_vals = list("mean"), tags = "predict") - ) +LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", + inherit = mlr3proba::LearnerSurv, + public = list( + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + initialize = function() { + param_set = ps( + K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")), + events = p_uty(default = NULL, tags = c("train", "predict")), + ztimes = p_uty(default = NULL, tags = c("train", "predict")), + zdelta = p_uty(default = NULL, tags = c("train", "predict")), + sparse = p_lgl(default = FALSE, tags = "train"), + theta = p_dbl(default = 0, tags = "train"), + omega = p_dbl(default = 1, tags = "train"), + a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"), + b = p_dbl(default = 1L, tags = "train"), + augment = p_lgl(default = FALSE, tags = "train"), + rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), + usequants = p_lgl(default = FALSE, tags = "train"), + rm.const = p_lgl(default = TRUE, tags = "train"), + type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"), + ntype = p_int(lower = 1, upper = 3, tags = "train"), + k = p_dbl(default = 2.0, lower = 0, tags = "train"), + power = p_dbl(default = 2.0, lower = 0, tags = "train"), + base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"), + offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"), + ntree = p_int(default = 50L, lower = 1L, tags = "train"), + numcut = p_int(default = 100L, lower = 1L, tags = "train"), + ndpost = p_int(default = 1000L, lower = 1L, tags = "train"), + nskip = p_int(default = 250L, lower = 0L, tags = "train"), + keepevery = p_int(default = 10L, lower = 1L, tags = "train"), + printevery = p_int(default = 100L, lower = 1L, tags = "train"), + seed = p_int(default = 99L, tags = "train"), + mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")), 
+ nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")), + openmp = p_lgl(default = TRUE, tags = "predict"), + quiet = p_lgl(default = TRUE, tags = "predict"), + importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"), + which.curve = p_dbl(lower = 0L, special_vals = list("mean"), tags = "predict") + ) + + # custom defaults + param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count", + which.curve = 0.5) # 0.5 quantile => median posterior + + super$initialize( + id = "surv.bart", + packages = "BART", + feature_types = c("logical", "integer", "numeric"), + predict_types = c("crank", "distr"), + param_set = param_set, + properties = c("importance", "missings"), + man = "mlr3extralearners::mlr_learners_surv.bart", + label = "Bayesian Additive Regression Trees" + ) + }, + + #' @description + #' Two types of importance scores are supported based on the value + #' of the parameter `importance`: + #' 1. `prob`: The mean selection probability of each feature in the trees, + #' extracted from the slot `varprob.mean`. + #' If `sparse = FALSE` (default), this is a fixed constant. + #' Recommended to use this option when `sparse = TRUE`. + #' 2. `count`: The mean observed count of each feature in the trees (average + #' number of times the feature was used in a tree decision rule across all + #' posterior draws), extracted from the slot `varcount.mean`. + #' This is the default importance scores. + #' + #' In both cases, higher values signify more important variables. + #' + #' @return Named `numeric()`. 
+ importance = function() { + if (is.null(self$model$model)) { + stopf("No model stored") + } - # custom defaults - param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count", - which.curve = 0.5) # 0.5 quantile => median posterior - - super$initialize( - id = "surv.bart", - packages = "BART", - feature_types = c("logical", "integer", "numeric"), - predict_types = c("crank", "distr"), - param_set = param_set, - properties = c("importance", "missings"), - man = "mlr3extralearners::mlr_learners_surv.bart", - label = "Bayesian Additive Regression Trees" - ) - }, - - #' @description - #' Two types of importance scores are supported based on the value - #' of the parameter `importance`: - #' 1. `prob`: The mean selection probability of each feature in the trees, - #' extracted from the slot `varprob.mean`. - #' If `sparse = FALSE` (default), this is a fixed constant. - #' Recommended to use this option when `sparse = TRUE`. - #' 2. `count`: The observed count of each feature in the trees, extracted - #' from the slot `varcount.mean`. - #' This is the default importance scores. - #' - #' In both cases, higher values signify more important variables. - #' - #' @return Named `numeric()`. 
- importance = function() { - if (is.null(self$model$model)) { - stopf("No model stored") - } - - pars = self$param_set$get_values(tags = "train") - - if (pars$importance == "prob") { - sort(self$model$model$varprob.mean[-1], decreasing = TRUE) - } else { - sort(self$model$model$varcount.mean[-1], decreasing = TRUE) - } + pars = self$param_set$get_values(tags = "train") + + if (pars$importance == "prob") { + sort(self$model$model$varprob.mean[-1], decreasing = TRUE) + } else { + sort(self$model$model$varcount.mean[-1], decreasing = TRUE) } - ), - - private = list( - .train = function(task) { - pars = self$param_set$get_values(tags = "train") - pars$importance = NULL # not used in the train function - - x.train = as.data.frame(task$data(cols = task$feature_names)) - truth = task$truth() - times = truth[,1] - delta = truth[,2] # delta => status - - list( - model = invoke( - BART::mc.surv.bart, - x.train = x.train, - times = times, - delta = delta, - .args = pars - ), - # need these for predict + } + ), + + private = list( + .train = function(task) { + pars = self$param_set$get_values(tags = "train") + pars$importance = NULL # not used in the train function + + x.train = as.data.frame(task$data(cols = task$feature_names)) + truth = task$truth() + times = truth[,1] + delta = truth[,2] # delta => status + + list( + model = invoke( + BART::mc.surv.bart, x.train = x.train, times = times, - delta = delta - ) - }, - - .predict = function(task) { - # get parameters with tag "predict" - pars = self$param_set$get_values(tags = "predict") - - # get newdata and ensure same ordering in train and predict - x.test = as.data.frame(ordered_features(task, self)) - - # subset parameters to use in `surv.pre.bart` - pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")] - - # transform data to be suitable for BART survival analysis (needs train data) - trans_data = invoke( - BART::surv.pre.bart, - times = self$model$times, - delta = self$model$delta, - x.train = 
self$model$x.train, - x.test = x.test, - .args = pars_pre + delta = delta, + .args = pars + ), + # need these for predict + x.train = x.train, + times = times, + delta = delta + ) + }, + + .predict = function(task) { + # get parameters with tag "predict" + pars = self$param_set$get_values(tags = "predict") + + # get newdata and ensure same ordering in train and predict + x.test = as.data.frame(ordered_features(task, self)) + + # subset parameters to use in `surv.pre.bart` + pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")] + + # transform data to be suitable for BART survival analysis (needs train data) + trans_data = invoke( + BART::surv.pre.bart, + times = self$model$times, + delta = self$model$delta, + x.train = self$model$x.train, + x.test = x.test, + .args = pars_pre + ) + + # subset parameters to use in `predict` + pars_pred = pars[names(pars) %in% c("mc.cores", "nice")] + + pred_fun = function() { + invoke( + predict, + self$model$model, + newdata = trans_data$tx.test, + .args = pars_pred ) + } - # subset parameters to use in `predict` - pars_pred = pars[names(pars) %in% c("mc.cores", "nice")] - - pred_fun = function() { - invoke( - predict, - self$model$model, - newdata = trans_data$tx.test, - .args = pars_pred - ) - } - - # don't print C++ generated info during prediction - if (pars$quiet) { - utils::capture.output({ - pred = pred_fun() - }) - } else { + # don't print C++ generated info during prediction + if (pars$quiet) { + utils::capture.output({ pred = pred_fun() - } - - # Number of test observations - N = task$nrow - # Number of unique times - K = pred$K - times = pred$times - # Number of posterior draws - M = nrow(pred$surv.test) - - # Convert full posterior survival matrix to 3D survival array - # See page 34-35 in Sparapani (2021) for more details - surv.array = aperm( - array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)), - c(3, 2, 1) - ) - - # distr => 3d survival array - # crank => expected mortality 
- mlr3proba::.surv_return(times = times, surv = surv.array, - which.curve = pars$which.curve) + }) + } else { + pred = pred_fun() } - ) + + # Number of test observations + N = task$nrow + # Number of unique times + K = pred$K + times = pred$times + # Number of posterior draws + M = nrow(pred$surv.test) + + # Convert full posterior survival matrix to 3D survival array + # See page 34-35 in Sparapani (2021) for more details + surv.array = aperm( + array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)), + c(3, 2, 1) + ) + + # distr => 3d survival array + # crank => expected mortality + mlr3proba::.surv_return(times = times, surv = surv.array, + which.curve = pars$which.curve) + } ) ) From 594ed5830f46d0adc5cfc86a9108a9702d242646 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 16 Oct 2023 17:58:36 +0200 Subject: [PATCH 42/47] explain better 'varcount.mean' --- man/mlr_learners_surv.bart.Rd | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index a8a1c42cc..e90a5fe8f 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -169,8 +169,9 @@ for the parameter \code{importance}: extracted from the slot \code{varprob.mean}. If \code{sparse = FALSE} (default), this is a fixed constant. Recommended to use this option when \code{sparse = TRUE}. -\item \code{count}: The observed count of each feature in the trees, extracted -from the slot \code{varcount.mean}. +\item \code{count}: The mean observed count of each feature in the trees (average +number of times the feature was used in a tree decision rule across all +posterior draws), extracted from the slot \code{varcount.mean}. This is the default importance scores. 
} From 4204c2d18e21a05d3148c2a82b9b82c28360078d Mon Sep 17 00:00:00 2001 From: john Date: Mon, 16 Oct 2023 18:04:41 +0200 Subject: [PATCH 43/47] small update of BART doc --- man/mlr_learners_surv.bart.Rd | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index e90a5fe8f..d937cf312 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -162,8 +162,8 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-LearnerSurvLearnerSurvBART-importance}{}}} \subsection{Method \code{importance()}}{ -Two types of importance scores are supported based on the initial value -for the parameter \code{importance}: +Two types of importance scores are supported based on the value +of the parameter \code{importance}: \enumerate{ \item \code{prob}: The mean selection probability of each feature in the trees, extracted from the slot \code{varprob.mean}. From 1e66dc9c4b2601b483090cfb7e23fd9faf3f300f Mon Sep 17 00:00:00 2001 From: john Date: Mon, 16 Oct 2023 19:54:36 +0200 Subject: [PATCH 44/47] add more doc for 'which.curve' --- R/learner_BART_surv_bart.R | 2 +- man/mlr_learners_surv.bart.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 2fb09247a..376d44b5b 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -22,7 +22,7 @@ #' @section Custom mlr3 parameters: #' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is initialized to `TRUE`. #' - `importance` allows to choose the type of importance. Default is `count`, see documentation of method `$importance()` for more details. -#' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. By default the **median posterior** is used. 
+#' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. If between (0,1) it is taken as the quantile of the curves otherwise if greater than 1 it is taken as the curve index, can also be 'mean'. By default the **median posterior** is used, i.e. `which.curve` is 0.5. #' #' @templateVar id surv.bart #' @template learner diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index d937cf312..3045803f6 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -32,7 +32,7 @@ prediction, see more info on \link[mlr3proba:PredictionSurv]{PredictionSurv}. \itemize{ \item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is initialized to \code{TRUE}. \item \code{importance} allows to choose the type of importance. Default is \code{count}, see documentation of method \verb{$importance()} for more details. -\item \code{which.curve} allows to choose which posterior draw will be used for the calculation of the \code{crank} prediction. By default the \strong{median posterior} is used. +\item \code{which.curve} allows to choose which posterior draw will be used for the calculation of the \code{crank} prediction. If between (0,1) it is taken as the quantile of the curves otherwise if greater than 1 it is taken as the curve index, can also be 'mean'. By default the \strong{median posterior} is used, i.e. \code{which.curve} is 0.5. 
} } From 83d09b99be151b2a9838121991f1645c0cfcff3b Mon Sep 17 00:00:00 2001 From: john Date: Mon, 16 Oct 2023 23:18:03 +0200 Subject: [PATCH 45/47] small style changes --- R/learner_BART_surv_bart.R | 20 +++++++++++++------- man/mlr_learners_surv.bart.Rd | 12 +++++++++--- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 376d44b5b..4ec31f464 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -20,9 +20,15 @@ #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. #' #' @section Custom mlr3 parameters: -#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is initialized to `TRUE`. -#' - `importance` allows to choose the type of importance. Default is `count`, see documentation of method `$importance()` for more details. -#' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. If between (0,1) it is taken as the quantile of the curves otherwise if greater than 1 it is taken as the curve index, can also be 'mean'. By default the **median posterior** is used, i.e. `which.curve` is 0.5. +#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is +#' initialized to `TRUE`. +#' - `importance` allows to choose the type of importance. Default is `count`, +#' see documentation of method `$importance()` for more details. +#' - `which.curve` allows to choose which posterior draw will be used for the +#' calculation of the `crank` prediction. If between (0,1) it is taken as the +#' quantile of the curves otherwise if greater than 1 it is taken as the curve +#' index, can also be 'mean'. By default the **median posterior** is used, +#' i.e. `which.curve` is 0.5. 
#' #' @templateVar id surv.bart #' @template learner @@ -127,8 +133,8 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", x.train = as.data.frame(task$data(cols = task$feature_names)) truth = task$truth() - times = truth[,1] - delta = truth[,2] # delta => status + times = truth[, 1] + delta = truth[, 2] # delta => status list( model = invoke( @@ -196,14 +202,14 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", # Convert full posterior survival matrix to 3D survival array # See page 34-35 in Sparapani (2021) for more details - surv.array = aperm( + surv_array = aperm( array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)), c(3, 2, 1) ) # distr => 3d survival array # crank => expected mortality - mlr3proba::.surv_return(times = times, surv = surv.array, + mlr3proba::.surv_return(times = times, surv = surv_array, which.curve = pars$which.curve) } ) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index 3045803f6..3c72ce758 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -30,9 +30,15 @@ prediction, see more info on \link[mlr3proba:PredictionSurv]{PredictionSurv}. \section{Custom mlr3 parameters}{ \itemize{ -\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is initialized to \code{TRUE}. -\item \code{importance} allows to choose the type of importance. Default is \code{count}, see documentation of method \verb{$importance()} for more details. -\item \code{which.curve} allows to choose which posterior draw will be used for the calculation of the \code{crank} prediction. If between (0,1) it is taken as the quantile of the curves otherwise if greater than 1 it is taken as the curve index, can also be 'mean'. By default the \strong{median posterior} is used, i.e. \code{which.curve} is 0.5. +\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is +initialized to \code{TRUE}. 
+\item \code{importance} allows to choose the type of importance. Default is \code{count}, +see documentation of method \verb{$importance()} for more details. +\item \code{which.curve} allows to choose which posterior draw will be used for the +calculation of the \code{crank} prediction. If between (0,1) it is taken as the +quantile of the curves otherwise if greater than 1 it is taken as the curve +index, can also be 'mean'. By default the \strong{median posterior} is used, +i.e. \code{which.curve} is 0.5. } } From 3b4db4380e0796cd21621d52b6630eb4ca6db3a9 Mon Sep 17 00:00:00 2001 From: john Date: Tue, 17 Oct 2023 09:19:50 +0200 Subject: [PATCH 46/47] fix hanging indent --- R/learner_BART_surv_bart.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 4ec31f464..6ce20cf84 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -82,7 +82,7 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", # custom defaults param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count", - which.curve = 0.5) # 0.5 quantile => median posterior + which.curve = 0.5) # 0.5 quantile => median posterior super$initialize( id = "surv.bart", @@ -210,7 +210,7 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", # distr => 3d survival array # crank => expected mortality mlr3proba::.surv_return(times = times, surv = surv_array, - which.curve = pars$which.curve) + which.curve = pars$which.curve) } ) ) From b3699a9fad72af37c02cb4b9283b8a7e6ffe4267 Mon Sep 17 00:00:00 2001 From: john Date: Tue, 17 Oct 2023 09:24:31 +0200 Subject: [PATCH 47/47] add no lint --- R/learner_BART_surv_bart.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 6ce20cf84..3c35f4c01 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -131,7 +131,7 @@ LearnerSurvLearnerSurvBART = 
R6Class("LearnerSurvLearnerSurvBART", pars = self$param_set$get_values(tags = "train") pars$importance = NULL # not used in the train function - x.train = as.data.frame(task$data(cols = task$feature_names)) + x.train = as.data.frame(task$data(cols = task$feature_names)) # nolint truth = task$truth() times = truth[, 1] delta = truth[, 2] # delta => status @@ -156,7 +156,7 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", pars = self$param_set$get_values(tags = "predict") # get newdata and ensure same ordering in train and predict - x.test = as.data.frame(ordered_features(task, self)) + x.test = as.data.frame(ordered_features(task, self)) # nolint # subset parameters to use in `surv.pre.bart` pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")]